Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions lib/Optimizer/Transforms/ApplyOpSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,34 @@ struct ApplyVariants {
/// Map from `func::FuncOp` to the variants to be created.
using ApplyOpAnalysisInfo = DenseMap<Operation *, ApplyVariants>;

/// Returns true when \p func contains at least one `func.call` whose callee,
/// looked up as a `func::FuncOp` symbol in \p module, declares an input of type
/// `!quake.veq<?>` (a veq without a specified size). Calls whose callee cannot
/// be resolved that way are ignored.
///
/// Callers use this to decide whether relaxed veq arguments may be specialized
/// (un-relaxed) during constant propagation: doing so for a function with such
/// a call would produce a type mismatch at the inner call, which still expects
/// the dynamic type.
///
/// This is deliberately conservative. Possible refinements:
///   1. Dataflow analysis: determine whether a specific argument actually
///      reaches such a call, so that unaffected arguments can still be
///      specialized.
///   2. Recursive specialization: specialize every callee in the call tree to
///      accept the concrete veq size. This propagates type information deeper
///      for better optimization at the cost of increased code size.
static bool hasCallWithDynamicVeq(func::FuncOp func, ModuleOp module) {
  return func
      .walk([&](func::CallOp call) -> WalkResult {
        auto target = module.lookupSymbol<func::FuncOp>(call.getCallee());
        if (target) {
          // Inspect the callee's declared parameter types; the first veq of
          // unspecified size ends the walk.
          for (auto paramTy : target.getFunctionType().getInputs()) {
            auto veq = dyn_cast<quake::VeqType>(paramTy);
            if (veq && !veq.hasSpecifiedSize())
              return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

/// This analysis scans the IR for `ApplyOp`s to see which ones need to have
/// variants created.
struct ApplyOpAnalysis {
Expand Down Expand Up @@ -99,11 +127,14 @@ struct ApplyOpAnalysis {
LLVM_DEBUG(llvm::dbgs() << "apply has constant arguments.\n");
} else {
if (auto relax = v.getDefiningOp<quake::RelaxSizeOp>()) {
// Also, specialize any relaxed veq types.
v = relax.getInputVec();
updateSignature = true;
LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
<< v.getType() << ")\n");
// Specialize relaxed veq types, but only if the function has no
// inner calls expecting dynamic !quake.veq<?> types.
if (!hasCallWithDynamicVeq(genericFunc, module)) {
v = relax.getInputVec();
updateSignature = true;
LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
<< v.getType() << ")\n");
}
}
inputTys.push_back(v.getType());
preservedArgs.push_back(v);
Expand Down
75 changes: 75 additions & 0 deletions python/tests/kernel/test_control_negations.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,81 @@ def control_simple_gate():
e.value)


def test_control_float_list_complex_real_access():
    """
    Regression test for a bug in cudaq.control() argument synthesis.

    The bug occurs when these three conditions are met:
    1. Using cudaq.control() to call a sub-kernel
    2. The sub-kernel has BOTH float AND list[complex] parameters
    3. The sub-kernel accesses .real on a complex value from the list

    Error: 'func.call' op operand type mismatch: expected operand type
           '!quake.veq<?>', but provided '!quake.veq<N>'
    RuntimeError: Could not successfully apply argument synth.

    This pattern is used in the krylov.ipynb notebook.

    The test only checks that sampling completes and yields a non-empty
    result; it does not assert on specific measurement counts.
    """

    # Sub-kernel combining all three trigger conditions: a float parameter, a
    # list[complex] parameter, and a `.real` access on a list element.
    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
        rx(dt * values[0].real, qubits[0])

    # Entry kernel: passes a 2-qubit register plus both classical arguments to
    # sub_kernel through cudaq.control(), with `ancilla` as the second argument.
    @cudaq.kernel
    def main_kernel(dt: float, values: list[complex]):
        ancilla = cudaq.qubit()
        qreg = cudaq.qvector(2)
        h(ancilla)
        cudaq.control(sub_kernel, ancilla, qreg, dt, values)

    # Before the fix this raised the RuntimeError quoted in the docstring.
    result = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
    assert len(result) > 0


def test_control_list_complex_real_access_no_float():
    """
    Verify that list[complex] with .real access works when there's no float param.
    This is a control test to confirm the bug is specific to the float + list[complex]
    combination.

    As with the regression test it accompanies, only a non-empty sample result
    is asserted.
    """

    # Same `.real` access on a list element, but without any float parameter.
    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, values: list[complex]):
        rx(values[0].real, qubits[0])

    # Entry kernel: forwards the list to sub_kernel through cudaq.control().
    @cudaq.kernel
    def main_kernel(values: list[complex]):
        ancilla = cudaq.qubit()
        qreg = cudaq.qvector(2)
        h(ancilla)
        cudaq.control(sub_kernel, ancilla, qreg, values)

    # This should work
    result = cudaq.sample(main_kernel, [0.5 + 0j, 0.25 + 0j])
    assert len(result) > 0


def test_control_float_list_complex_no_real_access():
    """
    Verify that float + list[complex] works when .real is not accessed.
    This is a control test to confirm the bug requires the .real access.

    Only a non-empty sample result is asserted.
    """

    # Same signature as the failing case, but `values` is never read: the
    # rotation angle is `dt` alone.
    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
        rx(dt, qubits[0])

    # Entry kernel: identical call shape to the regression test.
    @cudaq.kernel
    def main_kernel(dt: float, values: list[complex]):
        ancilla = cudaq.qubit()
        qreg = cudaq.qvector(2)
        h(ancilla)
        cudaq.control(sub_kernel, ancilla, qreg, dt, values)

    result = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
    assert len(result) > 0


# leave for gdb debugging
if __name__ == "__main__":
loc = os.path.abspath(__file__)
Expand Down
71 changes: 71 additions & 0 deletions test/Transforms/apply_ctrl_veq_type.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// ========================================================================== //
// Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// Regression test for a bug where veq type specialization caused type mismatches
// when creating control variants of functions that have inner calls to other
// kernels expecting dynamic !quake.veq<?> types.
//
// The bug scenario:
// 1. main_kernel allocates !quake.veq<2> and relaxes it to !quake.veq<?>
// 2. A controlled apply to thunk_kernel with a constant argument triggers
// constant propagation which would un-relax the veq type back to !quake.veq<2>
// 3. The thunk calls inner_kernel which expects !quake.veq<?>
// 4. When the .ctrl variant is created, the inner call has a type mismatch
//
// The fix: Don't specialize (un-relax) veq types when the function being
// specialized contains a func.call whose callee takes a dynamic !quake.veq<?>
// argument, since that callee still expects the dynamic type.
// An alternative fix would be to recursively specialize all callees in the
// call tree, but that would increase code size.

// RUN: cudaq-opt --apply-op-specialization=constant-prop=1 %s | FileCheck %s

module {
  // Inner kernel that expects dynamic veq type. It applies rx(%angle) to
  // element 0 of %qubits.
  func.func @inner_kernel(%qubits: !quake.veq<?>, %angle: f64) attributes {"cudaq-kernel"} {
    %0 = quake.extract_ref %qubits[0] : (!quake.veq<?>) -> !quake.ref
    quake.rx (%angle) %0 : (f64, !quake.ref) -> ()
    return
  }

  // Thunk kernel that calls inner_kernel - will be specialized with constant prop.
  // It forwards both arguments unchanged, so its veq parameter must keep the
  // !quake.veq<?> type that @inner_kernel declares.
  func.func @thunk_kernel(%qubits: !quake.veq<?>, %angle: f64) attributes {"cudaq-kernel"} {
    func.call @inner_kernel(%qubits, %angle) : (!quake.veq<?>, f64) -> ()
    return
  }

  // Main kernel that uses controlled apply with a constant angle
  func.func @main_kernel() attributes {"cudaq-entrypoint", "cudaq-kernel"} {
    %cst = arith.constant 0.1 : f64
    %ancilla = quake.alloca !quake.ref
    // Sized register that is relaxed to the dynamic type before being passed on.
    %qreg = quake.alloca !quake.veq<2>
    %relaxed = quake.relax_size %qreg : (!quake.veq<2>) -> !quake.veq<?>
    quake.h %ancilla : (!quake.ref) -> ()
    // Controlled apply with constant - triggers specialization
    quake.apply @thunk_kernel [%ancilla] %relaxed, %cst : (!quake.ref, !quake.veq<?>, f64) -> ()
    return
  }
}

// Verify that the .ctrl variant is created and veq type is NOT specialized
// (remains !quake.veq<?>) because the function calls @inner_kernel, which
// takes a dynamic !quake.veq<?> argument.

// CHECK-LABEL: func.func private @thunk_kernel.0.ctrl(
// CHECK-SAME: %[[CTRL:.*]]: !quake.veq<?>,
// CHECK-SAME: %[[QUBITS:.*]]: !quake.veq<?>)
// The veq argument keeps its dynamic type, so no relax_size is needed
// CHECK: call @inner_kernel(%[[QUBITS]],
// CHECK-SAME: : (!quake.veq<?>, f64) -> ()
// CHECK: return

// CHECK-LABEL: func.func @main_kernel()
// CHECK: %[[ANCILLA:.*]] = quake.alloca !quake.ref
// CHECK: %[[QREG:.*]] = quake.alloca !quake.veq<2>
// CHECK: %[[RELAXED:.*]] = quake.relax_size %[[QREG]]
// CHECK: quake.h %[[ANCILLA]]
// CHECK: %[[CONCAT:.*]] = quake.concat %[[ANCILLA]]
// CHECK: call @thunk_kernel.0.ctrl(%[[CONCAT]], %[[RELAXED]])
Loading