forked from NVIDIA/cuda-quantum
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathast_bridge.py
More file actions
5386 lines (4767 loc) · 242 KB
/
ast_bridge.py
File metadata and controls
5386 lines (4767 loc) · 242 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# ============================================================================ #
# Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
import ast
import importlib
import inspect
import textwrap
import numpy as np
import os
import sys
import types
from collections import deque
from cudaq.mlir._mlir_libs._quakeDialects import (cudaq_runtime, load_intrinsic,
gen_vector_of_complex_constant
)
from cudaq.kernel_types import qview
from cudaq.mlir.dialects import arith, cc, complex, func, math, quake
from cudaq.mlir.ir import (BoolAttr, Block, BlockArgument, Context, ComplexType,
DenseBoolArrayAttr, DenseI32ArrayAttr,
DenseI64ArrayAttr, DictAttr, F32Type, F64Type,
FlatSymbolRefAttr, FloatAttr, FunctionType,
InsertionPoint, IntegerAttr, IntegerType, Location,
Module, StringAttr, SymbolTable, TypeAttr, UnitAttr)
from cudaq.mlir.passmanager import PassManager
from .analysis import ValidateArgumentAnnotations, ValidateReturnStatements
from .kernel_signature import KernelSignature
from .utils import (Color, globalRegisteredOperations, globalRegisteredTypes,
nvqppPrefix, mlirTypeFromAnnotation, mlirTypeFromPyType,
mlirTypeToPyType, getMLIRContext, recover_func_op,
is_recovered_value_ok, recover_value_of_or_none,
cudaq__unique_attr_name, mlirTryCreateStructType,
resolve_qualified_symbol)
State = cudaq_runtime.State

# This file implements the CUDA-Q Python AST to MLIR conversion.
# It provides a `PyASTBridge` class that implements the `ast.NodeVisitor` type
# to walk the Python AST for a `cudaq.kernel` annotated function and generate
# valid MLIR code using `Quake`, `CC`, `Arith`, and `Math` dialects.

# Sentinel marking a dynamic (runtime-valued) index in a CC dialect
# `ComputePtrOp`. The C++ definition sets it to
# `std::numeric_limits<int32_t>::min()` (see the `ComputePtrOp` definition in
# `CCOps.td`); the same value is spelled out here.
kDynamicPtrIndex: int = -(1 << 31)

# Types allowed in a dataclass (per the name; the check that consumes this
# list is not part of this section of the file).
ALLOWED_TYPES_IN_A_DATACLASS = [int, float, bool, qview]
class PyScopedSymbolTable(object):
    """
    Symbol table mapping Python variable names to MLIR values, organized as a
    chain of scopes. Each scope also tracks which (nested) blocks are open, so
    that a lookup can detect a value that is defined according to Python
    scoping rules but is not usable at the current location in the generated
    MLIR.
    """

    class Scope(object):
        """
        A single scope rooted at an MLIR `Block`. Every symbol is recorded
        together with the ID of the block it belongs to.
        """

        def __init__(self, scope_root, parent=None):
            assert isinstance(scope_root, Block) and hasattr(
                scope_root, 'owner')
            self.root = scope_root
            self.parent = parent
            # ID given to the most recently begun block (-1: none yet).
            self._blockID = -1
            # IDs of the currently open blocks, innermost last.
            self._parentBlocks = deque()
            # symbol -> (value, ID of the block the value belongs to)
            self._symbols = {}

        def beginBlock(self):
            """Open a new nested block with a fresh ID."""
            self._blockID += 1
            self._parentBlocks.append(self._blockID)

        def endBlock(self):
            """Close the innermost open block."""
            assert self._parentBlocks
            self._parentBlocks.pop()

        def isDefined(self, symbol):
            """Return True if `symbol` has been defined in this scope."""
            return symbol in self._symbols

        def tryGet(self, symbol):
            """
            Returns the value of the given symbol in this scope,
            as well as a boolean indicating whether the value is
            valid in the current scope. Returns `(None, False)` if
            the symbol is not defined in this scope.
            """
            if symbol not in self._symbols:
                return None, False
            # We need to make sure that the symbol is not only defined,
            # but also accessible at this location in the MLIR we generate.
            # To check for this, the symbol table keeps track of which block
            # a symbol is defined in. Because/if some variables are stored
            # as values in the symbol table, a value defined in an inner
            # block may not be accessible in the outer block (MLIR fails
            # with "operand does not dominate this use"). We hence fail
            # with a comprehensive error if a symbol is defined according
            # to Python scoping rules, but not valid to use at the current
            # location in the generated MLIR.
            value, sid = self._symbols[symbol]
            return value, sid in self._parentBlocks

        def addOrUpdate(self, symbol, value):
            """
            Record `value` for `symbol`. Values whose owner is the scope root,
            or whose owner's parent is the root's owner, are associated with
            block 0; everything else is associated with the innermost open
            block.
            """
            assert self._parentBlocks
            if hasattr(value, 'owner'):
                if value.owner == self.root:
                    self._symbols[symbol] = (value, 0)
                    return
                if (hasattr(value.owner, 'parent') and
                        value.owner.parent == self.root.owner):
                    self._symbols[symbol] = (value, 0)
                    return
            self._symbols[symbol] = (value, self._parentBlocks[-1])

        @property
        def depth(self):
            """Number of currently open blocks in this scope."""
            return len(self._parentBlocks)

        def isFromParentBlock(self, symbol):
            """
            Return True if `symbol` belongs to an open block other than the
            innermost one.
            """
            if symbol in self._symbols:
                _, sid = self._symbols[symbol]
                return (sid in self._parentBlocks and
                        sid != self._parentBlocks[-1])
            return False

    def __init__(self, error_handler=None):

        def default_error_handler(msg):
            raise RuntimeError(msg)

        self._scope = None
        # Called with a message on misuse; raises RuntimeError by default.
        self.emitError = error_handler or default_error_handler

    def pushScope(self, scope_root):
        """Enter a new scope rooted at the given MLIR block."""
        self._scope = PyScopedSymbolTable.Scope(scope_root, parent=self._scope)

    def popScope(self):
        """Leave the current scope; all of its blocks must have been ended."""
        if not self._scope:
            self.emitError("symbol table has no scopes to pop")
        elif self._scope.depth > 0:
            self.emitError("unfinished block(s) in symbol table")
        else:
            self._scope = self._scope.parent

    def beginBlock(self):
        """Open a nested block in the current scope."""
        self._scope.beginBlock()

    def endBlock(self):
        """Close the innermost block of the current scope."""
        self._scope.endBlock()

    @property
    def scopeDepth(self):
        """Number of open blocks in the current scope."""
        return self._scope.depth

    @property
    def scopeRoot(self):
        """The MLIR block the current scope is rooted at."""
        return self._scope.root

    def __contains__(self, symbol):
        """
        Returns True if and only if a symbol with the given name
        is defined according to Python scoping rules.
        Note that depending on how values are represented in MLIR,
        it is possible that even though a symbol is defined
        according to Python rules, there are limitations for where
        that symbol can be used in the MLIR translation. In that
        case, retrieving the symbol from the symbol table will
        fail with an appropriate error.
        """
        scope = self._scope
        while scope:
            if scope.isDefined(symbol):
                return True
            scope = scope.parent
        return False

    def __setitem__(self, symbol, value):
        """
        Adds or updates the given symbol in the symbol table.
        Automatically adjusts the block association if the given
        value is a function argument or an allocation at the
        beginning of the scope.
        """
        if not self._scope:
            self.emitError("no scope defined")
            # Nothing to add to (only reached if the handler does not raise).
            return
        target_scope = self._scope
        # special handling for function arguments
        if (BlockArgument.isinstance(value) and
                isinstance(value.owner.owner, func.FuncOp)):
            # add to the scope of the closest function definition
            while (target_scope.parent and
                   not isinstance(target_scope.root.owner, func.FuncOp)):
                target_scope = target_scope.parent
        target_scope.addOrUpdate(symbol, value)

    def __getitem__(self, symbol):
        """
        Retrieves the value of the given symbol from the symbol table.
        Fails with a comprehensive error if the symbol is not defined,
        or if the symbol is defined but cannot be used at the current
        location (due to how it is represented in MLIR).
        """
        scope, value, valid = self._scope, None, False
        # The innermost scope that defines the symbol wins (shadowing).
        # `tryGet` returns None for "not defined here", so test identity
        # rather than truthiness.
        while scope is not None and value is None:
            value, valid = scope.tryGet(symbol)
            scope = scope.parent
        if value is None:
            self.emitError(f"name '{symbol}' is not defined")
        elif not valid:
            # We have a variable that is defined in an inner scope,
            # but not allocated in the main function body.
            # This case deviates from Python behavior, and we give
            # a hopefully comprehensive enough error.
            self.emitError(f"variable of type {value.type} " +
                           "is defined in a prior block and cannot be " +
                           "accessed or changed outside that block" +
                           os.linesep + f"(offending source -> {symbol})")
        else:
            return value
        return None

    def isFromParentScope(self, symbol):
        """
        Returns True if and only if the given symbol is defined
        in an outer scope (and not in the current one).
        """
        if not self._scope or self._scope.isDefined(symbol):
            return False
        scope = self._scope.parent
        while scope:
            if scope.isDefined(symbol):
                return True
            scope = scope.parent
        return False

    def isFromParentBlock(self, symbol):
        """
        Returns True if and only if the given symbol is defined in
        a parent block of the current scope.
        """
        return (self._scope is not None and
                self._scope.isFromParentBlock(symbol))

    @property
    def isEmpty(self):
        """
        Returns True if and only if there are no remaining scope frames.
        """
        return not self._scope

    @property
    def isInnerScope(self):
        """
        Returns True if and only if the current scope has a parent scope
        defined.
        """
        return self._scope is not None and self._scope.parent is not None
class CompilerError(RuntimeError):
    """
    Exception raised by the AST bridge to report an error in kernel code.
    The message passed in is used verbatim as the exception text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class PyStack(object):
    '''
    Manages the values produced while visiting Python AST nodes. Each visit
    of a node is expected to correspond to one stack frame. A value pushed
    from a frame is stored in that frame's parent, so the parent can pop it;
    the frame that produced a value cannot access it itself.
    '''

    class Frame(object):

        def __init__(self, parent=None):
            self.parent = parent
            # Values pushed by child frames; stays None until a child frame
            # is pushed on top of this one.
            self.entries = None

    def __init__(self, error_handler=None):

        def default_error_handler(msg):
            raise RuntimeError(msg)

        self._frame = None
        # Called with a message on misuse; raises RuntimeError by default.
        self.emitError = error_handler or default_error_handler

    def pushFrame(self):
        '''
        Push a new frame; to be done before processing a new AST node.
        '''
        parent = self._frame
        if parent and not parent.entries:
            # Give the parent a fresh container for the child's values.
            parent.entries = deque()
        self._frame = PyStack.Frame(parent=parent)

    def popFrame(self):
        '''
        Pop the current frame; to be done once an AST node has been
        processed. Every value made available to this frame must have been
        consumed first.
        '''
        current = self._frame
        if not current:
            self.emitError("stack has no frames to pop")
            return
        if current.entries:
            self.emitError(
                "all values must be processed before popping a frame")
            return
        self._frame = current.parent

    def pushValue(self, value):
        '''
        Push a value, making it available to the parent frame.
        '''
        current = self._frame
        if not current:
            self.emitError("cannot push value to empty stack")
            return
        if not current.parent:
            self.emitError("no parent frame is defined to push values to")
            return
        current.parent.entries.append(value)

    def popValue(self):
        '''
        Pop and return the value most recently pushed by a child frame.
        '''
        current = self._frame
        if not current:
            self.emitError("value stack is empty")
            return None
        if not current.entries:
            # This is the only error that may be directly user-facing even
            # when the bridge is doing its processing correctly, so the
            # message is kept general. Internally it means that either this
            # frame has not had a child, or the child produced no values.
            self.emitError("no valid value was created")
            return None
        return current.entries.pop()

    @property
    def isEmpty(self):
        '''
        True if and only if no frames remain on the stack.
        '''
        return not self._frame

    @property
    def currentNumValues(self):
        '''
        Number of values the current frame can pop (0 when there are none).
        '''
        current = self._frame
        if not current:
            self.emitError("no frame defined for empty stack")
            return 0
        return len(current.entries) if current.entries else 0
def recover_kernel_decorator(name):
    """
    Search the call stack for a variable called `name` and return its value
    if it is a kernel decorator (as determined by `isa_kernel_decorator`).

    Calling frames are searched innermost first; in each frame the local
    variables are checked before the globals. The first frame that defines
    `name` decides the result: if the value bound there is not a kernel
    decorator, None is returned without looking further out, since that
    closer binding shadows any outer one.

    Returns:
        The kernel decorator bound to `name`, or None.
    """
    from .kernel_decorator import isa_kernel_decorator
    # `context=0` skips reading source lines for every frame, which is costly
    # and not needed here.
    stack = inspect.stack(context=0)
    try:
        # stack[0] is this function's own frame. Its locals (`name`, `frame`,
        # ...) and this module's globals must not shadow the user's kernel.
        # NOTE(review): callers that live in this same module still expose
        # this module's globals through their own frames.
        for frameinfo in stack[1:]:
            frame = frameinfo.frame
            if name in frame.f_locals:
                candidate = frame.f_locals[name]
                return candidate if isa_kernel_decorator(candidate) else None
            if name in frame.f_globals:
                candidate = frame.f_globals[name]
                return candidate if isa_kernel_decorator(candidate) else None
        return None
    finally:
        # The list holds this function's own frame, which in turn holds the
        # list as a local; drop it to break the reference cycle promptly.
        del stack
class PyASTBridge(ast.NodeVisitor):
"""
The `PyASTBridge` class implements the `ast.NodeVisitor` type to convert a
python function definition (annotated with cudaq.kernel) to an MLIR
`ModuleOp` containing a `func.FuncOp` representative of the original python
function but leveraging the Quake and CC dialects provided by CUDA-Q. This
class keeps track of a MLIR Value stack that is pushed to and popped from
during visitation of the function AST nodes. We leverage the auto-generated
MLIR Python bindings for the internal C++ CUDA-Q dialects to build up the
MLIR code.
"""
def __init__(self,
             signature: KernelSignature,
             *,
             uniqueId=None,
             kernelModuleName=None,
             locationOffset=('', 0),
             verbose=False):
    """
    Set up the state needed to translate a kernel: the value stack, the
    scoped symbol table (variable name -> MLIR value), the MLIR context, and
    the MLIR module that is populated while the AST is walked.

    Args:
        signature: the kernel signature. As the AST is visited, it is
            extended to capture variables as required.
        uniqueId: optional identifier, stored as `self.uniqueId`.
        kernelModuleName: optional module name, stored as
            `self.kernelModuleName`.
        locationOffset: `(file path, line offset)` used to report source
            locations in warnings and errors.
        verbose: enable debug output through `debug_msg`.
    """

    def node_error(msg):
        # Internal bookkeeping failures are reported against the AST node
        # that is currently being visited.
        self.emitFatalError(f'processing error - {msg}', self.currentNode)

    self.symbolTable = PyScopedSymbolTable(error_handler=node_error)
    self.valueStack = PyStack(error_handler=node_error)

    # Inputs describing the kernel being compiled.
    self.signature = signature
    self.uniqueId = uniqueId
    self.kernelModuleName = kernelModuleName
    self.locationOffset = locationOffset
    self.verbose = verbose

    # MLIR context and the module that generated code is added to.
    self.ctx = getMLIRContext()
    self.loc = Location.unknown(context=self.ctx)
    self.module = Module.create(self.loc)

    # Indentation used by `debug_msg`.
    self.indent_level = 0
    self.indent = " " * 4

    # Visitor state flags and stacks.
    self.buildingFunctionBody = False
    self.inForBodyStack = deque()
    self.inIfStmtBlockStack = 0
    self.currentAssignVariableName = None
    self.walkingReturnNode = False
    self.controlNegations = []
    self.pushPointerValue = False
    self.isSubscriptRoot = False
    # AST node being visited; the default location for diagnostics.
    self.currentNode = None
def debug_msg(self, msg, node=None):
    """
    Print a debug message when verbose output is enabled.

    Args:
        msg: zero-argument callable that produces the message. It is only
            called when `self.verbose` is set, so callers can defer costly
            formatting.
        node: optional AST node. If given, its unparsed source is printed
            below the message, indented one level deeper.
    """
    if not self.verbose:
        return
    print(f'{self.indent * self.indent_level}{msg()}')
    if node is None:
        return
    try:
        source = ast.unparse(node)
    except Exception:
        # Best effort: some objects cannot be unparsed, and the source echo
        # is purely informational. (Unlike the former bare `except:`, this
        # no longer swallows KeyboardInterrupt/SystemExit.)
        return
    print(textwrap.indent(source, self.indent * (self.indent_level + 1)))
def emitWarning(self, msg, astNode=None):
    """
    Emit a warning, providing the user with source file information and
    the offending code.

    Args:
        msg: warning text.
        astNode: AST node the warning refers to; defaults to the node that
            is currently being visited.
    """
    codeFile = os.path.basename(self.locationOffset[0])
    if astNode is None:
        astNode = self.currentNode
    # Not every AST node carries position information; matches the guard in
    # `emitFatalError` instead of raising AttributeError.
    if astNode is not None and hasattr(astNode, 'lineno'):
        lineNumber = astNode.lineno + self.locationOffset[1] - 1
    else:
        lineNumber = ''
    offending_source = ''
    if astNode is not None:
        try:
            offending_source = ("\n\t (offending source -> " +
                                ast.unparse(astNode) + ")")
        except Exception:
            # Best effort (e.g. no `ast.unparse`, or an object that cannot
            # be unparsed): report the warning without the source echo.
            offending_source = ''
    print(Color.BOLD, end='')
    msg = (codeFile + ":" + str(lineNumber) + ": " + Color.YELLOW +
           "warning: " + Color.END + Color.BOLD + msg + offending_source +
           Color.END)
    print(msg)
def emitFatalError(self, msg, astNode=None):
    """
    Emit a fatal error, providing the user with source file information and
    the offending code.

    Args:
        msg: error text.
        astNode: AST node the error refers to; defaults to the node that is
            currently being visited.

    Raises:
        CompilerError: always; the message carries the file name, line
            number, and (when available) the unparsed offending source.
    """
    codeFile = os.path.basename(self.locationOffset[0])
    if astNode is None:
        astNode = self.currentNode
    lineNumber = ('' if astNode is None or not hasattr(astNode, 'lineno') else
                  astNode.lineno + self.locationOffset[1] - 1)
    offending_source = ''
    if astNode is not None:
        try:
            offending_source = ("\n\t (offending source -> " +
                                ast.unparse(astNode) + ")")
        except Exception:
            # Best effort: report the error without the source echo. (The
            # former bare `except:` also swallowed KeyboardInterrupt.)
            offending_source = ''
    print(Color.BOLD, end='')
    msg = (codeFile + ":" + str(lineNumber) + ": " + Color.RED + "error: " +
           Color.END + Color.BOLD + msg + offending_source + Color.END)
    raise CompilerError(msg)
def getVeqType(self, size=None):
    """
    Return a `quake.VeqType`. Pass the size of the `quake.veq` if known;
    omit it (or pass None) to get the type without a size.
    """
    # Identity comparison with the None singleton (PEP 8).
    if size is None:
        return quake.VeqType.get()
    return quake.VeqType.get(size)
def getRefType(self):
    """
    Return the Quake reference type, `quake.RefType`.
    """
    refTy = quake.RefType.get()
    return refTy
def isQuantumType(self, ty):
    """
    Return True if `ty` is one of the Quake quantum types `RefType`,
    `VeqType`, or `StruqType`; return False otherwise.
    """
    for quantumKind in (quake.RefType, quake.VeqType, quake.StruqType):
        if quantumKind.isinstance(ty):
            return True
    return False
def isFunctionArgument(self, value):
    """
    Return True if `value` is a block argument whose block is owned by a
    `func.FuncOp`, i.e. it is an argument of a function; False otherwise.
    """
    if not BlockArgument.isinstance(value):
        return False
    return isinstance(value.owner.owner, func.FuncOp)
def containsList(self, ty, innerListsOnly=False):
    """
    Return True if the given type is a vector (`cc.StdvecType`), or a struct
    with a member that is or (recursively) contains a vector.

    With `innerListsOnly=True`, a vector type only counts if its element
    type in turn is or contains a vector. Struct members are checked the
    same way regardless of this flag.
    """
    if cc.StdvecType.isinstance(ty):
        if not innerListsOnly:
            return True
        return self.containsList(cc.StdvecType.getElementType(ty))
    if cc.StructType.isinstance(ty):
        return any(
            self.containsList(memberTy)
            for memberTy in cc.StructType.getTypes(ty))
    return False
def getIntegerType(self, width=64):
    """
    Return a signless MLIR `IntegerType` with `width` bits (64 by default).
    """
    intTy = IntegerType.get_signless(width)
    return intTy
def getIntegerAttr(self, type, value):
    """
    Wrap the integer `value` in an MLIR `IntegerAttr` of the given
    `IntegerType`.
    """
    intAttr = IntegerAttr.get(type, value)
    return intAttr
def getFloatType(self, width=64):
    """
    Return the MLIR float type of the given bit width: `f64` for 64 (the
    default; Python `float` and `numpy.float64`) or `f32` for 32
    (`numpy.float32`). Any other width is reported as a fatal error.
    """
    if width not in (32, 64):
        self.emitFatalError(
            f'unsupported width {width} requested for float type',
            self.currentNode)
        return None
    return F64Type.get() if width == 64 else F32Type.get()
def getFloatAttr(self, type, value):
    """
    Wrap the float `value` in an MLIR `FloatAttr` of the given float type
    (single or double precision).
    """
    floatAttr = FloatAttr.get(type, value)
    return floatAttr
def getConstantFloat(self, value, width=64):
    """
    Emit an `arith.constant` holding the Python float `value` as a float of
    the given bit width, and return the op's result value.
    """
    return self.getConstantFloatWithType(value,
                                         self.getFloatType(width=width))
def getConstantFloatWithType(self, value, ty):
    """
    Emit an `arith.constant` of float type `ty` holding the Python float
    `value`, and return the op's result value.
    """
    valueAttr = self.getFloatAttr(ty, value)
    constantOp = arith.ConstantOp(ty, valueAttr)
    return constantOp.result
def getComplexType(self, width=64):
    """
    Return an MLIR complex type whose parts are floats of the given bit
    width. Width 64 corresponds to Python `complex` / `numpy.complex128`,
    width 32 to `numpy.complex64` (element type `numpy.float32`).
    """
    elementTy = self.getFloatType(width=width)
    return self.getComplexTypeWithElementType(elementTy)
def getComplexTypeWithElementType(self, eTy):
    """
    Return the MLIR complex type whose real and imaginary parts have the
    float type `eTy`.
    """
    complexTy = ComplexType.get(eTy)
    return complexTy
def getConstantComplex(self, value, width=64):
    """
    Emit a `complex.create` op for the Python complex `value`, with the real
    and imaginary parts stored as float constants of the given bit width,
    and return its result value.
    """
    elementTy = self.getFloatType(width=width)
    return self.getConstantComplexWithElementType(value, elementTy)
def getConstantComplexWithElementType(self, value, eTy):
    """
    Emit a `complex.create` op for the Python complex `value`, with the real
    and imaginary parts stored as float constants of type `eTy`, and return
    its result value.
    """
    complexTy = self.getComplexTypeWithElementType(eTy)
    realPart = self.getConstantFloatWithType(value.real, eTy)
    imagPart = self.getConstantFloatWithType(value.imag, eTy)
    return complex.CreateOp(complexTy, realPart, imagPart).result
def getConstantInt(self, value, width=64):
    """
    Emit an `arith.constant` holding the integer `value` as a signless
    integer of `width` bits (64 by default), and return its result value.
    """
    intTy = self.getIntegerType(width)
    valueAttr = self.getIntegerAttr(intTy, value)
    return arith.ConstantOp(intTy, valueAttr).result
def __arithmetic_to_bool(self, value):
    """
    Convert an integer or floating point value to an `i1` by comparing it
    with zero. Values that already are `i1` are returned unchanged; any
    other type is reported as a fatal error.
    """
    valueTy = value.type
    if valueTy == self.getIntegerType(1):
        return value
    # Predicates 1 (`cmpi`) and 13 (`cmpf`) are the `arith` enum values for
    # "ne" and "une" respectively.
    # NOTE(review): confirm against the arith dialect predicate enums.
    if IntegerType.isinstance(valueTy):
        zero = self.getConstantInt(0, width=IntegerType(valueTy).width)
        predicate = IntegerAttr.get(self.getIntegerType(), 1)
        return arith.CmpIOp(predicate, value, zero).result
    isSingle = F32Type.isinstance(valueTy)
    if isSingle or F64Type.isinstance(valueTy):
        zero = self.getConstantFloat(0, width=32 if isSingle else 64)
        predicate = IntegerAttr.get(self.getIntegerType(), 13)
        return arith.CmpFOp(predicate, value, zero).result
    self.emitFatalError("value cannot be converted to bool",
                        self.currentNode)
def changeOperandToType(self, ty, operand, allowDemotion=False):
    """
    Change the type of an operand to a specified type. This function
    primarily handles type conversions and promotions to higher types
    (complex > float > int). Demotion of floating type to integer is not
    allowed by default. Regardless of whether demotion is allowed, types
    will be cast to smaller widths.

    Args:
        ty: the requested MLIR type.
        operand: the MLIR value to convert.
        allowDemotion: if True, a floating point operand may be converted to
            an integer type. The flag is also forwarded to the element-wise
            conversions of complex, vector, and struct values.

    Returns:
        A value of type `ty`: `operand` itself if no conversion is needed,
        otherwise the result of the newly created conversion op(s).

    Raises:
        CompilerError: (via `emitFatalError`) if no supported conversion
            from `operand.type` to `ty` exists.
    """
    if ty == operand.type:
        return operand
    if cc.CallableType.isinstance(ty):
        # A callable only matches a value whose type is exactly the
        # callable's underlying function type; no conversion is attempted.
        fctTy = cc.CallableType.getFunctionType(ty)
        if fctTy == operand.type:
            return operand
        self.emitFatalError(
            f'cannot convert value of type {operand.type} to '
            f'the requested type {fctTy}', self.currentNode)
    if ComplexType.isinstance(ty):
        complexType = ComplexType(ty)
        floatType = complexType.element_type
        if ComplexType.isinstance(operand.type):
            # complex -> complex: convert the real and imaginary parts
            # separately when their float types differ.
            otherComplexType = ComplexType(operand.type)
            otherFloatType = otherComplexType.element_type
            if (floatType != otherFloatType):
                real = self.changeOperandToType(
                    floatType,
                    complex.ReOp(operand).result,
                    allowDemotion=allowDemotion)
                imag = self.changeOperandToType(
                    floatType,
                    complex.ImOp(operand).result,
                    allowDemotion=allowDemotion)
                return complex.CreateOp(complexType, real, imag).result
        else:
            # non-complex -> complex: the converted operand becomes the real
            # part and the imaginary part is 0.0.
            real = self.changeOperandToType(floatType,
                                            operand,
                                            allowDemotion=allowDemotion)
            imag = self.getConstantFloatWithType(0.0, floatType)
            return complex.CreateOp(complexType, real, imag).result
    if (cc.StdvecType.isinstance(ty)):
        if cc.StdvecType.isinstance(operand.type):
            # vector -> vector: element-wise conversion (the helper may
            # return the original vector when no copy is needed).
            eleTy = cc.StdvecType.getElementType(ty)
            return self.__copyVectorAndConvertElements(
                operand,
                eleTy,
                allowDemotion=allowDemotion,
                alwaysCopy=False)
    if (cc.StructType.isinstance(ty)):
        if cc.StructType.isinstance(operand.type):
            expectedEleTys = cc.StructType.getTypes(ty)
            currentEleTys = cc.StructType.getTypes(operand.type)
            # struct -> struct: only between structs with the same number of
            # members; each member is converted recursively.
            if len(expectedEleTys) == len(currentEleTys):

                def conversion(idx, value):
                    return self.changeOperandToType(
                        expectedEleTys[idx],
                        value,
                        allowDemotion=allowDemotion)

                return self.__copyStructAndConvertElements(
                    operand,
                    expectedTy=ty,
                    allowDemotion=allowDemotion,
                    conversion=conversion)
    if F64Type.isinstance(ty):
        if F32Type.isinstance(operand.type):
            return cc.CastOp(ty, operand).result
        if IntegerType.isinstance(operand.type):
            # `i1` (bool) is zero-extended; other integers are treated as
            # signed.
            zeroext = IntegerType(operand.type).width == 1
            return cc.CastOp(ty, operand, sint=not zeroext,
                             zint=zeroext).result
    if F32Type.isinstance(ty):
        if F64Type.isinstance(operand.type):
            return cc.CastOp(ty, operand).result
        if IntegerType.isinstance(operand.type):
            # Same bool/signed handling as for `f64` above.
            zeroext = IntegerType(operand.type).width == 1
            return cc.CastOp(ty, operand, sint=not zeroext,
                             zint=zeroext).result
    if IntegerType.isinstance(ty):
        if allowDemotion and (F64Type.isinstance(operand.type) or
                              F32Type.isinstance(operand.type)):
            # Float -> integer demotion (signed cast) directly to `ty`; the
            # width check below then sees matching widths.
            operand = cc.CastOp(ty, operand, sint=True, zint=False).result
        if IntegerType.isinstance(operand.type):
            requested_width = IntegerType(ty).width
            operand_width = IntegerType(operand.type).width
            if requested_width == operand_width:
                return operand
            elif requested_width < operand_width:
                # Narrowing to `i1` compares with zero (truthiness) instead
                # of a plain cast; other narrowing uses a plain `cc.CastOp`.
                if requested_width == 1:
                    return self.__arithmetic_to_bool(operand)
                return cc.CastOp(ty, operand).result
            # Widening: zero-extend `i1`, sign-extend everything else.
            return cc.CastOp(ty,
                             operand,
                             sint=operand_width != 1,
                             zint=operand_width == 1).result
    # Every combination not handled above ends up here.
    self.emitFatalError(
        f'cannot convert value of type {operand.type} '
        f'to the requested type {ty}', self.currentNode)
def simulationPrecision(self):
    """
    Return the precision reported by the currently selected target, as a
    `cudaq_runtime.SimulationPrecision` value.
    """
    return cudaq_runtime.get_target().get_precision()
def simulationDType(self):
    """
    Return the MLIR complex type matching the current target's simulation
    precision: complex of `f64` for `SimulationPrecision.fp64`, complex of
    `f32` otherwise. These correspond to `numpy.complex128` and
    `numpy.complex64`, respectively.
    """
    isDouble = (self.simulationPrecision() ==
                cudaq_runtime.SimulationPrecision.fp64)
    return self.getComplexType(width=64 if isDouble else 32)
def pushValue(self, value):
    """
    Make an MLIR value available to the visit of the parent AST node by
    pushing it onto the value stack (logged when verbose).
    """
    self.debug_msg(lambda: f'push {value}')
    stack = self.valueStack
    stack.pushValue(value)
def popValue(self):
    """
    Pop and return the MLIR value most recently pushed by a child visit
    (logged when verbose).
    """
    popped = self.valueStack.popValue()
    self.debug_msg(lambda: f'pop {popped}')
    return popped
def popAllValues(self, expectedNumVals):
    """
    Pop every value currently available on the value stack and return them
    as a list, most recently pushed first. It is a fatal error if the number
    of values differs from `expectedNumVals`.
    """
    available = self.valueStack.currentNumValues
    values = [self.popValue() for _ in range(available)]
    if len(values) == expectedNumVals:
        return values
    self.emitFatalError(
        "processing error - expression did not produce a valid "
        "value in this context", self.currentNode)
    return values
def pushForBodyStack(self, bodyBlockArgs):
    """
    Record entry into a for-loop body block, remembering the body block's
    arguments.
    """
    forStack = self.inForBodyStack
    forStack.append(bodyBlockArgs)
def popForBodyStack(self):
    """
    Record exit from the innermost for-loop body block.
    """
    forStack = self.inForBodyStack
    forStack.pop()
def pushIfStmtBlockStack(self):
    """
    Record entry into the then- or else-block of an if statement.
    """
    self.inIfStmtBlockStack = self.inIfStmtBlockStack + 1
def popIfStmtBlockStack(self):
    """
    Record exit from the then- or else-block of an if statement. Must be
    paired with a prior `pushIfStmtBlockStack`.
    """
    assert self.inIfStmtBlockStack > 0
    self.inIfStmtBlockStack = self.inIfStmtBlockStack - 1
def isInForBody(self):
    """
    Return True if the current insertion point is inside a for-loop body
    block.
    """
    return bool(self.inForBodyStack)
def isInIfStmtBlock(self):
    """
    Return True if the current insertion point is inside the then- or
    else-block of an if statement.
    """
    return 0 < self.inIfStmtBlockStack
def hasTerminator(self, block):
    """
    Return True if the last operation of `block` is a terminator (as judged
    by `cudaq_runtime.isTerminator`); an empty block has none.
    """
    ops = block.operations
    numOps = len(ops)
    if numOps == 0:
        return False
    return cudaq_runtime.isTerminator(ops[numOps - 1])
def isArithmeticType(self, type):
    """
    Return True if `type` is an MLIR integer, `f64`, `f32`, or complex type;
    False otherwise.
    """
    return any(
        kind.isinstance(type)
        for kind in (IntegerType, F64Type, F32Type, ComplexType))
def __isSupportedNumpyFunction(self, id):
return id in ['sin', 'cos', 'sqrt', 'ceil', 'exp']
def __isSupportedVectorFunction(self, id):
return id in ['front', 'back', 'append']
def __isSimpleGate(self, id):
return id in ['h', 'x', 'y', 'z', 's', 't']
def __isAdjointSimpleGate(self, id):
return id in ['sdg', 'tdg']
def __isControlledSimpleGate(self, id):
if id == '' or id[0] != 'c':
return False
return self.__isSimpleGate(id[1:])
def __isRotationGate(self, id):
return id in ['rx', 'ry', 'rz', 'r1']
def __isControlledRotationGate(self, id):
if id == '' or id[0] != 'c':
return False
return self.__isRotationGate(id[1:])
def __isMeasurementGate(self, id):
return id in ['mx', 'my', 'mz']
def __isUnitaryGate(self, id):
    """
    Return True if `id` names a unitary operation: a simple, rotation,
    adjoint, or controlled gate, 'swap', 'u3', 'exp_pauli', or a
    registered custom operation.
    """
    gateChecks = (self.__isSimpleGate, self.__isRotationGate,
                  self.__isAdjointSimpleGate, self.__isControlledSimpleGate,
                  self.__isControlledRotationGate)
    if any(check(id) for check in gateChecks):
        return True
    return (id in ('swap', 'u3', 'exp_pauli') or
            id in globalRegisteredOperations)
def __createStdvecWithKnownValues(self, listElementValues):
    """
    Build a `cc.stdvec` from a non-empty list of MLIR values that all have
    the same type. An array sized for the list is created with `cc.alloca`,
    each value is stored into its slot, and the array is wrapped with
    `cc.stdvec_init`.

    Boolean (`i1`) elements are stored as `i8` in the backing array, while
    the resulting vector type keeps `i1` as its element type.
    """
    assert (len(set((v.type for v in listElementValues))) == 1)
    arrSize = self.getConstantInt(len(listElementValues))
    boolTy = self.getIntegerType(1)
    isBool = listElementValues[0].type == boolTy
    storageTy = (self.getIntegerType(8)
                 if isBool else listElementValues[0].type)
    alloca = cc.AllocaOp(cc.PointerType.get(cc.ArrayType.get(storageTy)),
                         TypeAttr.get(storageTy),
                         seqSize=arrSize).result
    slotPtrTy = cc.PointerType.get(storageTy)
    dynamicIndex = DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)
    for position, element in enumerate(listElementValues):
        slotAddr = cc.ComputePtrOp(slotPtrTy, alloca,
                                   [self.getConstantInt(position)],
                                   dynamicIndex).result
        if isBool:
            # Widen the bool to the `i8` storage type before storing it.
            element = self.changeOperandToType(storageTy, element)
        cc.StoreOp(element, slotAddr)
    vecTy = cc.StdvecType.get(boolTy if isBool else storageTy)
    return cc.StdvecInitOp(vecTy, alloca, length=arrSize).result
def __createStructWithKnownValues(self, mlirVals, name=None):
    """
    Build a struct value with one member per given MLIR value.

    If `mlirTryCreateStructType` yields a `quake.struq` type, the value is
    assembled with `quake.MakeStruqOp`. Otherwise a `cc.UndefOp` is filled
    member by member with `cc.InsertValueOp`. If no struct type can be
    created, a fatal error is emitted (the message cites hybrid
    quantum-classical types and nested quantum structs).

    Args:
        mlirVals: the member values, in member order.
        name: optional struct name passed to the type constructor.

    Returns:
        The MLIR value of the new struct.
    """
    structTy = mlirTryCreateStructType([item.type for item in mlirVals],
                                       name=name,
                                       context=self.ctx)
    if structTy is None:
        self.emitFatalError(
            "Hybrid quantum-classical data types and nested "
            "quantum structs are not allowed.", self.currentNode)
    if quake.StruqType.isinstance(structTy):
        # If we have a quantum struct, we cannot allocate classical memory
        # and load / store quantum type values to that memory space, so use
        # `quake.MakeStruqOp`.
        return quake.MakeStruqOp(structTy, mlirVals).result
    # Start from the undef op's *result*: with no members the loop below
    # never runs, and callers must still receive an `mlir.Value` rather than
    # the op object.
    result = cc.UndefOp(structTy).result
    for idx, element in enumerate(mlirVals):
        result = cc.InsertValueOp(
            structTy, result, element,
            DenseI64ArrayAttr.get([idx], context=self.ctx)).result
    return result
def getStructMemberIdx(self, memberName, structTy):
    """
    For the given struct type and member variable name, return the index of
    the variable in the struct and the specific MLIR type for the variable.

    Args:
        memberName: name of the member being accessed.
        structTy: a `cc.StructType` or `quake.StruqType` whose name refers
            to a registered dataclass.

    Returns:
        `(index, mlirType)` for the member.

    Raises:
        CompilerError: (via `emitFatalError`) for `tuple` structs,
            unregistered dataclasses, or unknown member names.
    """
    if cc.StructType.isinstance(structTy):
        structName = cc.StructType.getName(structTy)
    else:
        structName = quake.StruqType.getName(structTy)
    if structName == 'tuple':
        self.emitFatalError('`tuple` does not support attribute access')
    if not globalRegisteredTypes.isRegisteredClass(structName):
        self.emitFatalError(f'Dataclass is not registered: {structName}')
    _, userType = globalRegisteredTypes.getClassAttributes(structName)
    # Member order in the registered attributes defines the struct layout.
    memberNames = list(userType)
    if memberName not in memberNames:
        self.emitFatalError(
            f'Invalid struct member: {structName}.{memberName} '
            f'(members={memberNames})')
    structIdx = memberNames.index(memberName)
    return structIdx, mlirTypeFromPyType(userType[memberName], self.ctx)
def __copyStructAndConvertElements(self,
                                   struct,
                                   expectedTy=None,
                                   allowDemotion=False,
                                   conversion=None):
    """
    Build a new struct value from the members of `struct`. If a conversion
    is provided, it is applied to each member before the member's type is
    changed to match the corresponding member type of `expectedTy`.

    Args:
        struct: a value of `cc.StructType`.
        expectedTy: the `cc.StructType` to produce; defaults to the type of
            `struct`. Must have the same number of members.
        allowDemotion: forwarded to `changeOperandToType` for each member.
        conversion: optional `(index, value) -> value` callable applied to
            each extracted member.

    Returns:
        The new struct value of type `expectedTy`.
    """
    assert cc.StructType.isinstance(struct.type)
    if expectedTy is None:
        expectedTy = struct.type
    assert cc.StructType.isinstance(expectedTy)
    eleTys = cc.StructType.getTypes(struct.type)
    expectedEleTys = cc.StructType.getTypes(expectedTy)
    assert len(eleTys) == len(expectedEleTys)
    # Start from the undef op's *result* so that a struct without members
    # still yields an `mlir.Value` rather than the op object.
    returnVal = cc.UndefOp(expectedTy).result
    for idx, eleTy in enumerate(eleTys):
        element = cc.ExtractValueOp(
            eleTy, struct, [],
            DenseI32ArrayAttr.get([idx], context=self.ctx)).result
        element = conversion(idx, element) if conversion else element
        element = self.changeOperandToType(expectedEleTys[idx],
                                           element,
                                           allowDemotion=allowDemotion)
        returnVal = cc.InsertValueOp(
            expectedTy, returnVal, element,
            DenseI64ArrayAttr.get([idx], context=self.ctx)).result
    return returnVal
# Create a new vector with source elements converted to the target element
# type if needed.
def __copyVectorAndConvertElements(self,
source,
targetEleType=None,
allowDemotion=False,
alwaysCopy=False,
conversion=None):
'''
Creates a new vector with the requested element type. Returns the
original vector if the requested element type already matches the
current element type unless `alwaysCopy` is set to True. If a
conversion is provided, applies the conversion to each element before
changing its type to match the `targetEleType`. If `alwaysCopy` is set
to True, return a shallow copy of the vector by default (conversion can
be used to create a deep copy).