Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions python/cudaq/kernel/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,40 @@

import ast
import inspect
import importlib
import textwrap
from typing import Optional, Type

from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from cudaq.mlir.dialects import cc
from .utils import mlirTypeFromAnnotation


class HasReturnNodeVisitor(ast.NodeVisitor):
class FunctionDefVisitor(ast.NodeVisitor):
"""
This visitor will visit the function definition and report
true if that function has a return statement.
This visitor will visit the function definition of `kernel_name` and report
type annotations and whether the function has a return statement.
"""

def __init__(self):
self.hasReturnNode = False
arg_annotations: list[(str, Type)] = []
return_annotation: Optional[Type] = None
has_return_statement: bool = False
found: bool = False

def __init__(self, kernel_name: str):
self.kernel_name: str = kernel_name

def visit_FunctionDef(self, node):
for n in node.body:
if isinstance(n, ast.Return) and n.value != None:
self.hasReturnNode = True
if node.name == self.kernel_name:
self.found = True
self.arg_annotations = [
(arg.arg, arg.annotation) for arg in node.args.args
]
self.return_annotation = node.returns
self.has_return_statement = any(
isinstance(n, ast.Return) and n.value != None
for n in node.body)

def generic_visit(self, node):
if self.found:
# skip traversing the rest of the AST once found
return
super().generic_visit(node)


class FindDepFuncsVisitor(ast.NodeVisitor):
Expand Down
150 changes: 50 additions & 100 deletions python/cudaq/kernel/ast_bridge.py

Large diffs are not rendered by default.

23 changes: 10 additions & 13 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,39 +1360,36 @@ def resolve_callable_arg(self, insPt, target):
closure here.
Returns a `CreateLambdaOp` closure.
"""
# Add the target kernel to the current module.
cudaq_runtime.updateModule(self.uniqName, self.module, target.qkeModule)
fulluniq = nvqppPrefix + target.uniqName
fn = recover_func_op(self.module, fulluniq)

# build the closure to capture the lifted `args`
thisPyMod = recover_calling_module()
if target.defModule != thisPyMod:
m = target.defModule
else:
m = None
fulluniq = nvqppPrefix + target.uniqName
fn = recover_func_op(self.module, fulluniq)
funcTy = fn.type
if target.firstLiftedPos:
moduloInTys = funcTy.inputs[:target.firstLiftedPos]
else:
moduloInTys = funcTy.inputs
callableTy = cc.CallableType.get(self.ctx, moduloInTys, funcTy.results)
funcTy = target.signature.get_lifted_type()
callableTy = target.signature.get_callable_type()
with insPt, self.loc:
lamb = cc.CreateLambdaOp(callableTy, loc=self.loc)
lamb.attributes.__setitem__('function_type', TypeAttr.get(funcTy))
initRegion = lamb.initRegion
initBlock = Block.create_at_start(initRegion, moduloInTys)
initBlock = Block.create_at_start(initRegion, target.arg_types())
inner = InsertionPoint(initBlock)
with inner:
vs = []
for ba in initBlock.arguments:
vs.append(ba)
for i, a in enumerate(target.liftedArgs):
v = recover_value_of(a, m)
for var in target.captured_variables():
v = recover_value_of(var.name, m)
if isa_kernel_decorator(v):
# The recursive step
v = self.resolve_callable_arg(inner, v)
else:
argTy = funcTy.inputs[target.firstLiftedPos + i]
v = self.__getMLIRValueFromPythonArg(v, argTy)
v = self.__getMLIRValueFromPythonArg(v, var.type)
vs.append(v)
if funcTy.results:
call = func.CallOp(fn, vs).result
Expand Down
Loading
Loading