Skip to content

Commit 6531dc3

Browse files
authored
[python] Move kernel signature inference out of MLIR compilation (#3930)
I've broken up this PR into two commits that can be reviewed independently: The first commit contains "clean-ups". It doesn't introduce any new functionality nor should it change how anything is currently working. It just removes function arguments that aren't used, renames a couple of things and should make the code more readable. I'm including this here because it made my actual work much easier. These breaking changes are acceptable as they are within the `python/cudaq/kernel` folder, which is not publicly exposed (discussed with Bettina).

The second commit is the real contribution:

- The main goal is to parse the kernel signature _before_ running the MLIR compilation. This is required so that we can delay MLIR compilation until the first kernel invocation (in a future PR).
- To this end, an AST visitor in `python/cudaq/kernel/analysis.py` is extended to parse the function annotations. This can be run before MLIR compilation.
- A type `KernelSignature` is introduced to consolidate all kernel type information that was scattered across many attributes in one place.
- `KernelSignature` now only stores the kernel types as MLIR types (before, both Python and MLIR types were stored). This makes it easier to handle as there is a single source of truth (and Python types can be recovered using `mlirTypeToPyType`).

EDIT: Added a third commit that fixes issue #2895. I've noted that a lot of this Python code uses CamelCase, even though our style guide ([Google style guide](#2895)) recommends snake_case. I'm using snake_case where new names/variables are introduced, but limited the number of existing names that had to be changed.

closes #2895

---------

Signed-off-by: Luca Mondada <luca@mondada.net>
1 parent b235c5a commit 6531dc3

File tree

15 files changed

+792
-292
lines changed

15 files changed

+792
-292
lines changed

python/cudaq/kernel/analysis.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,41 @@
88

99
import ast
1010
import inspect
11-
import importlib
1211
import textwrap
12+
from typing import Optional, Type
1313

14-
from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
15-
from cudaq.mlir.dialects import cc
16-
from .utils import mlirTypeFromAnnotation
1714

18-
19-
class FunctionDefVisitor(ast.NodeVisitor):
    """
    This visitor will visit the function definition of `kernel_name` and
    report its type annotations and whether the function has a return
    statement that returns a value.
    """

    # (argument name, annotation AST node) for each positional argument of
    # the kernel; the annotation node is None for unannotated arguments.
    # NOTE: fixed from `list[(str, Type)]` — a tuple inside a subscript is
    # not a valid generic; PEP 585 requires `list[tuple[...]]`.
    arg_annotations: list[tuple[str, Optional[Type]]]
    # AST node of the `-> ...` return annotation, or None if absent.
    return_annotation: Optional[Type] = None
    # True when the kernel body contains a `return <expr>` statement.
    # NOTE(review): only top-level statements of the body are inspected
    # (matching the previous HasReturnNodeVisitor); returns nested inside
    # `if`/`for` blocks are not detected — confirm this is intended.
    has_return_statement: bool = False
    # True once the kernel's FunctionDef has been located.
    found: bool = False

    def __init__(self, kernel_name: str):
        self.kernel_name: str = kernel_name
        self.arg_annotations = []

    def visit_FunctionDef(self, node):
        if node.name == self.kernel_name:
            self.found = True
            self.arg_annotations = [
                (arg.arg, arg.annotation) for arg in node.args.args
            ]
            self.return_annotation = node.returns
            # A bare `return` (node value is None) does not count as
            # returning a value, hence the explicit `is not None` test
            # (idiomatic replacement for the previous `!= None`).
            self.has_return_statement = any(
                isinstance(n, ast.Return) and n.value is not None
                for n in node.body)

    def generic_visit(self, node):
        if self.found:
            # Skip traversing the rest of the AST once the kernel was found.
            return
        super().generic_visit(node)
3246

3347

3448
class FindDepFuncsVisitor(ast.NodeVisitor):

python/cudaq/kernel/ast_bridge.py

Lines changed: 50 additions & 100 deletions
Large diffs are not rendered by default.

python/cudaq/kernel/kernel_builder.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,39 +1360,36 @@ def resolve_callable_arg(self, insPt, target):
13601360
closure here.
13611361
Returns a `CreateLambdaOp` closure.
13621362
"""
1363+
# Add the target kernel to the current module.
13631364
cudaq_runtime.updateModule(self.uniqName, self.module, target.qkeModule)
1365+
fulluniq = nvqppPrefix + target.uniqName
1366+
fn = recover_func_op(self.module, fulluniq)
1367+
13641368
# build the closure to capture the lifted `args`
13651369
thisPyMod = recover_calling_module()
13661370
if target.defModule != thisPyMod:
13671371
m = target.defModule
13681372
else:
13691373
m = None
1370-
fulluniq = nvqppPrefix + target.uniqName
1371-
fn = recover_func_op(self.module, fulluniq)
1372-
funcTy = fn.type
1373-
if target.firstLiftedPos:
1374-
moduloInTys = funcTy.inputs[:target.firstLiftedPos]
1375-
else:
1376-
moduloInTys = funcTy.inputs
1377-
callableTy = cc.CallableType.get(self.ctx, moduloInTys, funcTy.results)
1374+
funcTy = target.signature.get_lifted_type()
1375+
callableTy = target.signature.get_callable_type()
13781376
with insPt, self.loc:
13791377
lamb = cc.CreateLambdaOp(callableTy, loc=self.loc)
13801378
lamb.attributes.__setitem__('function_type', TypeAttr.get(funcTy))
13811379
initRegion = lamb.initRegion
1382-
initBlock = Block.create_at_start(initRegion, moduloInTys)
1380+
initBlock = Block.create_at_start(initRegion, target.arg_types())
13831381
inner = InsertionPoint(initBlock)
13841382
with inner:
13851383
vs = []
13861384
for ba in initBlock.arguments:
13871385
vs.append(ba)
1388-
for i, a in enumerate(target.liftedArgs):
1389-
v = recover_value_of(a, m)
1386+
for var in target.captured_variables():
1387+
v = recover_value_of(var.name, m)
13901388
if isa_kernel_decorator(v):
13911389
# The recursive step
13921390
v = self.resolve_callable_arg(inner, v)
13931391
else:
1394-
argTy = funcTy.inputs[target.firstLiftedPos + i]
1395-
v = self.__getMLIRValueFromPythonArg(v, argTy)
1392+
v = self.__getMLIRValueFromPythonArg(v, var.type)
13961393
vs.append(v)
13971394
if funcTy.results:
13981395
call = func.CallOp(fn, vs).result

0 commit comments

Comments
 (0)