Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 23 additions & 51 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,13 @@ class keeps track of a MLIR Value stack that is pushed to and popped from
MLIR code.
"""

def __init__(self, **kwargs):
def __init__(self,
knownResultType=None,
*,
uniqueId=None,
kernelModuleName=None,
locationOffset=('', 0),
verbose=False):
"""
The constructor. Initializes the `mlir.Value` stack, the `mlir.Context`,
and the `mlir.Module` that we will be building upon. This class keeps
Expand All @@ -391,38 +397,14 @@ def node_error(msg):

self.symbolTable = PyScopedSymbolTable(error_handler=node_error)
self.valueStack = PyStack(error_handler=node_error)
self.knownResultType = kwargs[
'knownResultType'] if 'knownResultType' in kwargs else None
self.uniqueId = kwargs['uniqueId'] if 'uniqueId' in kwargs else None
self.kernelModuleName = kwargs[
'kernelModuleName'] if 'kernelModuleName' in kwargs else None
if 'existingModule' in kwargs:
self.module = kwargs['existingModule']
self.ctx = self.module.context
self.loc = Location.unknown(context=self.ctx)
else:
self.ctx = getMLIRContext()
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(self.loc)

# If the driver of this AST bridge instance has indicated that there is
# a return type from analysis on the Python AST, then we want to set the
# known result type so that the FuncOp can have it.
if 'returnTypeIsFromPython' in kwargs and kwargs[
'returnTypeIsFromPython'] and self.knownResultType is not None:
self.knownResultType = mlirTypeFromPyType(self.knownResultType,
self.ctx)

self.capturedVars = {}
self.dependentCaptureVars = {}
self.liftedArgs = []
self.locationOffset = kwargs[
'locationOffset'] if 'locationOffset' in kwargs else ('', 0)
self.disableEntryPointTag = (kwargs['disableEntryPointTag']
if 'disableEntryPointTag' in kwargs else
False)
self.disableNvqppPrefix = kwargs[
'disableNvqppPrefix'] if 'disableNvqppPrefix' in kwargs else False
self.knownResultType = knownResultType
self.uniqueId = uniqueId
self.kernelModuleName = kernelModuleName
self.ctx = getMLIRContext()
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(self.loc)

self.locationOffset = locationOffset
self.indent_level = 0
self.indent = 4 * " "
self.buildingFunctionBody = False
Expand All @@ -433,7 +415,7 @@ def node_error(msg):
self.controlNegations = []
self.pushPointerValue = False
self.isSubscriptRoot = False
self.verbose = 'verbose' in kwargs and kwargs['verbose']
self.verbose = verbose
self.currentNode = None
self.firstLiftedPos = None

Expand Down Expand Up @@ -1763,10 +1745,7 @@ def visit_FunctionDef(self, node):

# the full function name in MLIR is `__nvqpp__mlirgen__` + the
# function name
if self.disableNvqppPrefix:
fullName = self.name
else:
fullName = nvqppPrefix + self.name
fullName = nvqppPrefix + self.name

# Create the FuncOp
f = func.FuncOp(fullName, (self.argTypes, [] if self.knownResultType
Expand All @@ -1776,9 +1755,10 @@ def visit_FunctionDef(self, node):

# Set this kernel as an entry point if the argument types are
# classical only
areQuantumTypes = [self.isQuantumType(ty) for ty in self.argTypes]
anyQuantumType = any(
self.isQuantumType(ty) for ty in self.argTypes)
f.attributes.__setitem__('cudaq-kernel', UnitAttr.get())
if True not in areQuantumTypes and not self.disableEntryPointTag:
if not anyQuantumType:
f.attributes.__setitem__('cudaq-entrypoint', UnitAttr.get())

# Create the entry block
Expand Down Expand Up @@ -1833,7 +1813,7 @@ def visit_FunctionDef(self, node):
"processing error - unprocessed scope(s) in symbol table",
node)

if True not in areQuantumTypes:
if not anyQuantumType:
attr = DictAttr.get(
{
fullName:
Expand Down Expand Up @@ -5336,13 +5316,9 @@ def visit_Name(self, node):
# is handled elsewhere.
return

# node.id is a non-local symbol. Lift it to a formal argument.
self.dependentCaptureVars[node.id] = value
# If `node.id` is in `liftedArgs`, it should already be in the
# symbol table and processed.
# If `node.id` is already captured, it should be in the symbol table
# and processed.
assert not node.id in self.liftedArgs
if node.id not in self.liftedArgs:
self.liftedArgs.append(node.id)

# Append as a new argument
argTy = mlirTypeFromPyType(type(value), self.ctx, argInstance=value)
Expand Down Expand Up @@ -5391,8 +5367,6 @@ def compile_to_mlir(uniqueId, astModule, **kwargs):
verbose = 'verbose' in kwargs and kwargs['verbose']
returnType = kwargs['returnType'] if 'returnType' in kwargs else None
lineNumberOffset = kwargs['location'] if 'location' in kwargs else ('', 0)
parentVariables = kwargs[
'parentVariables'] if 'parentVariables' in kwargs else {}
preCompile = kwargs['preCompile'] if 'preCompile' in kwargs else False
kernelModuleName = kwargs[
'kernelModuleName'] if 'kernelModuleName' in kwargs else None
Expand All @@ -5401,9 +5375,7 @@ def compile_to_mlir(uniqueId, astModule, **kwargs):
bridge = PyASTBridge(uniqueId=uniqueId,
verbose=verbose,
knownResultType=returnType,
returnTypeIsFromPython=True,
locationOffset=lineNumberOffset,
capturedVariables=parentVariables,
kernelModuleName=kernelModuleName)

ValidateArgumentAnnotations(bridge).visit(astModule)
Expand Down
16 changes: 0 additions & 16 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,22 +1402,6 @@ def resolve_callable_arg(self, insPt, target):
cc.ReturnOp([])
return lamb

def c_if(self, measurement, function):
"""
Apply the `function` to the :class:`Kernel` if the provided single-qubit
`measurement` returns the 1-state.

Args:
measurement (:class:`QuakeValue`): The handle to the single qubit
measurement instruction.
function (Callable): The function to conditionally apply to the
:class:`Kernel`.

Raises:
RuntimeError: No longer supported.
"""
emitFatalError(
"`c_if` is no longer supported. Use kernel mode with `run` API.")

def for_loop(self, start, stop, function):
"""
Expand Down
30 changes: 7 additions & 23 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,10 @@ def __init__(self,
verbose=False,
module=None,
kernelName=None,
funcSrc=None,
signature=None,
location=None,
overrideGlobalScopedVars=None,
decorator=None,
fromBuilder=False):

if funcSrc is not None:
emitWarning(
"Passing 'funcSrc' to PyKernelDecorator is deprecated. Pass a string to `function` instead."
)
function = funcSrc
decorator=None):

self.location = location
self.signature = signature
Expand Down Expand Up @@ -169,13 +161,6 @@ def signatureWithCallables(self):
return True
return False

def getKernelType(self):
if self.returnType:
resTys = [self.returnType]
else:
resTys = []
return FunctionType.get(inputs=self.argTypes, results=resTys)

def pre_compile(self):
"""
Compile the Python AST to portable Quake.
Expand All @@ -196,7 +181,7 @@ def pre_compile(self):
id(self),
self.astModule,
verbose=self.verbose,
returnType=self.returnType,
returnType=self.return_type,
location=self.location,
parentVariables=self.globalScopedVars,
preCompile=True,
Expand Down Expand Up @@ -455,7 +440,7 @@ def formal_arity(self):
return self.firstLiftedPos
return len(self.argTypes)

def handle_call_arguments(self, *args, ignoreReturnType=False):
def handle_call_arguments(self, *args):
"""
Resolve all the arguments at the call site for this decorator.
"""
Expand Down Expand Up @@ -492,9 +477,9 @@ def get_none_type(self):
return NoneType.get(self.qkeModule.context)

def handle_call_results(self):
if not self.returnType:
if not self.return_type:
return self.get_none_type()
return mlirTypeFromPyType(self.returnType, self.qkeModule.context)
return mlirTypeFromPyType(self.return_type, self.qkeModule.context)

def launch_args_required(self):
"""
Expand Down Expand Up @@ -643,13 +628,13 @@ def _parse_signature_from_python(self):
emitFatalError('CUDA-Q kernel has return statement '
'but no return type annotation.')

self.returnType = self.signature.get('return', None)
self.return_type = self.signature.get('return', None)

def _parse_signature_from_mlir(self):
funcOp = recover_func_op(self.qkeModule, nvqppPrefix + self.uniqName)
fnTy = FunctionType(TypeAttr(funcOp.attributes['function_type']).value)
self.argTypes = fnTy.inputs
self.returnType = fnTy.results[0] if fnTy.results else None
self.return_type = fnTy.results[0] if fnTy.results else None

def _parse_ast(self):
self.astModule = ast.parse(self.funcSrc)
Expand All @@ -667,7 +652,6 @@ def mk_decorator(builder):
that handles both CUDA-Q kernel object classes more unified.
"""
return PyKernelDecorator(None,
fromBuilder=True,
module=builder.module,
kernelName=builder.uniqName)

Expand Down
4 changes: 2 additions & 2 deletions python/cudaq/runtime/draw.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand All @@ -19,7 +19,7 @@ def _detail_draw(format, decorator, *args):
str(decorator.formal_arity()) + " expected.")
# Must handle arguments exactly like this is a `callsite` to the decorator.
specMod, processedArgs = decorator.handle_call_arguments(*args)
retTy = decorator.returnType
retTy = decorator.return_type
if not retTy:
retTy = decorator.get_none_type()
# Arguments are resolved, so go ahead and do the draw functionality, which
Expand Down
6 changes: 3 additions & 3 deletions python/cudaq/runtime/resource_count.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -35,8 +35,8 @@ def estimate_resources(kernel, *args, **kwargs):
else:
decorator = mk_decorator(kernel)
specMod, processedArgs = decorator.handle_call_arguments(*args)
returnTy = (decorator.returnType
if decorator.returnType else decorator.get_none_type())
returnTy = (decorator.return_type
if decorator.return_type else decorator.get_none_type())
choice = kwargs.get("choice", None)
return cudaq_runtime.estimate_resources_impl(decorator.uniqName, specMod,
returnTy, choice,
Expand Down
8 changes: 4 additions & 4 deletions python/cudaq/runtime/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def __broadcastSample(kernel,
def _detail_check_conditionals_on_measure(kernel):
has_conditionals_on_measure_result = False
if isa_kernel_decorator(kernel):
if kernel.returnType is not None:
if kernel.return_type is not None:
raise RuntimeError(
f"The `sample` API only supports kernels that return None "
f"(void). Kernel '{kernel.name}' has return type "
f"'{kernel.returnType}'. Consider using `run` for kernels "
f"'{kernel.return_type}'. Consider using `run` for kernels "
f"that return values.")
# Only check for kernels that are compiled, not library-mode kernels (e.g., photonics)
if kernel.qkeModule is not None:
Expand Down Expand Up @@ -245,8 +245,8 @@ def sample_async(decorator,
if (not isinstance(shots_count, int)) or (shots_count < 0):
raise RuntimeError(
"Invalid `shots_count`. Must be a non-negative number.")
if (decorator.returnType and
decorator.returnType != decorator.get_none_type()):
if (decorator.return_type and
decorator.return_type != decorator.get_none_type()):
raise RuntimeError("The `sample_async` API only supports kernels that "
"return None (void). Consider using `run_async` for "
"kernels that return values.")
Expand Down
8 changes: 4 additions & 4 deletions python/cudaq/runtime/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def get_state(kernel, *args):
else:
decorator = mk_decorator(kernel)
specMod, processedArgs = decorator.handle_call_arguments(*args)
returnTy = (decorator.returnType
if decorator.returnType else decorator.get_none_type())
returnTy = (decorator.return_type
if decorator.return_type else decorator.get_none_type())
return cudaq_runtime.get_state_impl(decorator.uniqName, specMod, returnTy,
*processedArgs)

Expand Down Expand Up @@ -73,8 +73,8 @@ def get_state_async(kernel, *args, qpu_id=0):
else:
decorator = mk_decorator(kernel)
specMod, processedArgs = decorator.handle_call_arguments(*args)
returnTy = (decorator.returnType
if decorator.returnType else decorator.get_none_type())
returnTy = (decorator.return_type
if decorator.return_type else decorator.get_none_type())
return cudaq_runtime.get_state_async_impl(decorator.uniqName, specMod,
returnTy, qpu_id, *processedArgs)

Expand Down
6 changes: 3 additions & 3 deletions python/cudaq/runtime/translate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -83,8 +83,8 @@ def bell_pair():
raise RuntimeError(f"Invalid number of arguments passed to translate. "
f"{suppliedArgs} given, {deducedArgs} deduced, and "
f"{launchArgsReq} required.")
retTy = (decorator.returnType
if decorator.returnType else decorator.get_none_type())
retTy = (decorator.return_type
if decorator.return_type else decorator.get_none_type())
# Arguments are resolved. Specialize this kernel and translate to the
# selected transport layer.
return cudaq_runtime.translate_impl(decorator.uniqName, specMod, retTy,
Expand Down
6 changes: 3 additions & 3 deletions python/cudaq/runtime/unitary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2025 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -37,7 +37,7 @@ def bell():
else:
decorator = mk_decorator(kernel)
specMod, processedArgs = decorator.handle_call_arguments(*args)
returnTy = (decorator.returnType
if decorator.returnType else decorator.get_none_type())
returnTy = (decorator.return_type
if decorator.return_type else decorator.get_none_type())
return cudaq_runtime.get_unitary_impl(decorator.uniqName, specMod, returnTy,
*processedArgs)