Skip to content

Commit b2f3c25

Browse files
* Addressing review comments - use the type from discriminate (not i1
by default) Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
1 parent 5ec6d9c commit b2f3c25

File tree

2 files changed

+101
-21
lines changed

2 files changed

+101
-21
lines changed

lib/Optimizer/Transforms/ExpandMeasurements.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,22 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
5151
rewriter.template create<arith::AddIOp>(loc, totalToRead, vecSz);
5252
}
5353

54-
// 2. Create the buffer.
55-
auto i1Ty = rewriter.getI1Type();
54+
// 2. Determine the element type from users (default to `i1`).
55+
Type elemTy = rewriter.getI1Type();
56+
for (auto *user : measureOp.getMeasOut().getUsers()) {
57+
if (auto disc = dyn_cast<quake::DiscriminateOp>(user))
58+
if (auto svTy =
59+
dyn_cast<cudaq::cc::StdvecType>(disc.getResult().getType())) {
60+
elemTy = svTy.getElementType();
61+
break;
62+
}
63+
}
64+
// 3. Create the buffer.
5665
auto i8Ty = rewriter.getI8Type();
5766
Value buff =
5867
rewriter.template create<cudaq::cc::AllocaOp>(loc, i8Ty, totalToRead);
5968

60-
// 3. Measure each individual qubit and insert the result, in order, into
69+
// 4. Measure each individual qubit and insert the result, in order, into
6170
// the buffer. For registers/vectors, loop over the entire set of qubits.
6271
Value buffOff = rewriter.template create<arith::ConstantIntOp>(loc, 0, 64);
6372
Value one = rewriter.template create<arith::ConstantIntOp>(loc, 1, 64);
@@ -66,7 +75,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
6675
if (isa<quake::RefType>(v.getType())) {
6776
auto meas = rewriter.template create<A>(loc, measTy, v).getMeasOut();
6877
auto bit =
69-
rewriter.template create<quake::DiscriminateOp>(loc, i1Ty, meas);
78+
rewriter.template create<quake::DiscriminateOp>(loc, elemTy, meas);
7079
Value addr = rewriter.template create<cudaq::cc::ComputePtrOp>(
7180
loc, cudaq::cc::PointerType::get(i8Ty), buff, buffOff);
7281
auto bitByte = rewriter.template create<cudaq::cc::CastOp>(
@@ -84,7 +93,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
8493
builder.template create<quake::ExtractRefOp>(loc, v, iv);
8594
auto meas = builder.template create<A>(loc, measTy, qv);
8695
auto bit = builder.template create<quake::DiscriminateOp>(
87-
loc, i1Ty, meas.getMeasOut());
96+
loc, elemTy, meas.getMeasOut());
8897
if (auto registerName = measureOp.getRegisterNameAttr())
8998
meas.setRegisterName(registerName);
9099
Value offset =
@@ -99,15 +108,15 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
99108
}
100109
}
101110

102-
// 4. Use the buffer as an initialization expression and create the
103-
// std::vec<bool> value.
104-
auto stdvecTy = cudaq::cc::StdvecType::get(rewriter.getContext(), i1Ty);
111+
// 5. Use the buffer as an initialization expression and create the
112+
// std::vec<iN> value.
113+
auto stdvecTy = cudaq::cc::StdvecType::get(rewriter.getContext(), elemTy);
105114
for (auto *out : measureOp.getMeasOut().getUsers())
106115
if (auto disc = dyn_cast_if_present<quake::DiscriminateOp>(out)) {
107-
auto ptrArrI1Ty =
108-
cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(i1Ty));
116+
auto ptrArrTy =
117+
cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(elemTy));
109118
auto buffCast =
110-
rewriter.template create<cudaq::cc::CastOp>(loc, ptrArrI1Ty, buff);
119+
rewriter.template create<cudaq::cc::CastOp>(loc, ptrArrTy, buff);
111120
rewriter.template replaceOpWithNewOp<cudaq::cc::StdvecInitOp>(
112121
disc, stdvecTy, buffCast, totalToRead);
113122
}
@@ -119,7 +128,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
119128

120129
// Expand a `quake.discriminate` on a `!quake.measurements<N>` value into
121130
// individual `get_measure` + `discriminate` operations, producing a
122-
// `!cc.stdvec<i1>`.
131+
// `!cc.stdvec<iK>` where K is inferred from the original result type.
123132
class ExpandDiscriminatePattern
124133
: public OpRewritePattern<quake::DiscriminateOp> {
125134
public:
@@ -133,7 +142,9 @@ class ExpandDiscriminatePattern
133142
return failure();
134143

135144
auto loc = discOp.getLoc();
136-
auto i1Ty = rewriter.getI1Type();
145+
auto stdvecResTy =
146+
cast<cudaq::cc::StdvecType>(discOp.getResult().getType());
147+
auto elemTy = stdvecResTy.getElementType();
137148
auto i8Ty = rewriter.getI8Type();
138149
Value totalToRead;
139150

@@ -150,7 +161,8 @@ class ExpandDiscriminatePattern
150161
std::size_t n = measTy.getSize();
151162
for (std::size_t i = 0; i < n; ++i) {
152163
Value getMeas = rewriter.create<quake::GetMeasureOp>(loc, measVal, i);
153-
Value bit = rewriter.create<quake::DiscriminateOp>(loc, i1Ty, getMeas);
164+
Value bit =
165+
rewriter.create<quake::DiscriminateOp>(loc, elemTy, getMeas);
154166
Value idx = rewriter.create<arith::ConstantIntOp>(loc, i, 64);
155167
Value addr = rewriter.create<cudaq::cc::ComputePtrOp>(
156168
loc, cudaq::cc::PointerType::get(i8Ty), buff, idx);
@@ -160,12 +172,12 @@ class ExpandDiscriminatePattern
160172
}
161173
}
162174

163-
auto stdvecTy = cudaq::cc::StdvecType::get(rewriter.getContext(), i1Ty);
164-
auto ptrArrI1Ty =
165-
cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(i1Ty));
166-
auto buffCast = rewriter.create<cudaq::cc::CastOp>(loc, ptrArrI1Ty, buff);
167-
rewriter.replaceOpWithNewOp<cudaq::cc::StdvecInitOp>(discOp, stdvecTy,
168-
buffCast, totalToRead);
175+
auto ptrArrElemTy =
176+
cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(elemTy));
177+
auto buffCast =
178+
rewriter.create<cudaq::cc::CastOp>(loc, ptrArrElemTy, buff);
179+
rewriter.replaceOpWithNewOp<cudaq::cc::StdvecInitOp>(
180+
discOp, stdvecResTy, buffCast, totalToRead);
169181
return success();
170182
}
171183
};

test/Transforms/expand_measurements.qke

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ func.func @discriminate_single(%m : !quake.measure) -> i1 {
5151
// CHECK: return %[[VAL_1]] : i1
5252
// CHECK: }
5353

54-
5554
func.func @discriminate_measurements_1(%ms : !quake.measurements<1>) -> !cc.stdvec<i1> {
5655
%bits = quake.discriminate %ms : (!quake.measurements<1>) -> !cc.stdvec<i1>
5756
return %bits : !cc.stdvec<i1>
@@ -71,3 +70,72 @@ func.func @discriminate_measurements_1(%ms : !quake.measurements<1>) -> !cc.stdv
7170
// CHECK: %[[VAL_9:.*]] = cc.stdvec_init %[[VAL_8]], %[[VAL_1]] : (!cc.ptr<!cc.array<i1 x ?>>, i64) -> !cc.stdvec<i1>
7271
// CHECK: return %[[VAL_9]] : !cc.stdvec<i1>
7372
// CHECK: }
73+
74+
func.func @discriminate_measurements_i4(%ms : !quake.measurements<2>) -> !cc.stdvec<i4> {
75+
%bits = quake.discriminate %ms : (!quake.measurements<2>) -> !cc.stdvec<i4>
76+
return %bits : !cc.stdvec<i4>
77+
}
78+
79+
// CHECK-LABEL: func.func @discriminate_measurements_i4(
80+
// CHECK-SAME: %[[VAL_0:.*]]: !quake.measurements<2>) -> !cc.stdvec<i4> {
81+
// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i64
82+
// CHECK: %[[VAL_2:.*]] = cc.alloca i8{{\[}}%[[VAL_1]] : i64]
83+
// CHECK: %[[VAL_3:.*]] = quake.get_measure %[[VAL_0]][0] : (!quake.measurements<2>) -> !quake.measure
84+
// CHECK: %[[VAL_4:.*]] = quake.discriminate %[[VAL_3]] : (!quake.measure) -> i4
85+
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64
86+
// CHECK: %[[VAL_6:.*]] = cc.compute_ptr %[[VAL_2]]{{\[}}%[[VAL_5]]] : (!cc.ptr<!cc.array<i8 x ?>>, i64) -> !cc.ptr<i8>
87+
// CHECK: %[[VAL_7:.*]] = cc.cast unsigned %[[VAL_4]] : (i4) -> i8
88+
// CHECK: cc.store %[[VAL_7]], %[[VAL_6]] : !cc.ptr<i8>
89+
// CHECK: %[[VAL_8:.*]] = quake.get_measure %[[VAL_0]][1] : (!quake.measurements<2>) -> !quake.measure
90+
// CHECK: %[[VAL_9:.*]] = quake.discriminate %[[VAL_8]] : (!quake.measure) -> i4
91+
// CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64
92+
// CHECK: %[[VAL_11:.*]] = cc.compute_ptr %[[VAL_2]]{{\[}}%[[VAL_10]]] : (!cc.ptr<!cc.array<i8 x ?>>, i64) -> !cc.ptr<i8>
93+
// CHECK: %[[VAL_12:.*]] = cc.cast unsigned %[[VAL_9]] : (i4) -> i8
94+
// CHECK: cc.store %[[VAL_12]], %[[VAL_11]] : !cc.ptr<i8>
95+
// CHECK: %[[VAL_13:.*]] = cc.cast %[[VAL_2]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<!cc.array<i4 x ?>>
96+
// CHECK: %[[VAL_14:.*]] = cc.stdvec_init %[[VAL_13]], %[[VAL_1]] : (!cc.ptr<!cc.array<i4 x ?>>, i64) -> !cc.stdvec<i4>
97+
// CHECK: return %[[VAL_14]] : !cc.stdvec<i4>
98+
// CHECK: }
99+
100+
func.func @expand_mz_veq_i3() -> !cc.stdvec<i3> {
101+
%0 = quake.alloca !quake.veq<2>
102+
%measOut = quake.mz %0 : (!quake.veq<2>) -> !quake.measurements<2>
103+
%bits = quake.discriminate %measOut : (!quake.measurements<2>) -> !cc.stdvec<i3>
104+
return %bits : !cc.stdvec<i3>
105+
}
106+
107+
// CHECK-LABEL: func.func @expand_mz_veq_i3() -> !cc.stdvec<i3> {
108+
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<2>
109+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
110+
// CHECK: %[[VAL_2:.*]] = quake.veq_size %[[VAL_0]] : (!quake.veq<2>) -> i64
111+
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : i64
112+
// CHECK: %[[VAL_4:.*]] = cc.alloca i8{{\[}}%[[VAL_3]] : i64]
113+
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64
114+
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : i64
115+
// CHECK: %[[VAL_7:.*]] = quake.veq_size %[[VAL_0]] : (!quake.veq<2>) -> i64
116+
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64
117+
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i64
118+
// CHECK: %[[VAL_10:.*]] = cc.loop while ((%[[VAL_11:.*]] = %[[VAL_8]]) -> (i64)) {
119+
// CHECK: %[[VAL_12:.*]] = arith.cmpi slt, %[[VAL_11]], %[[VAL_7]] : i64
120+
// CHECK: cc.condition %[[VAL_12]](%[[VAL_11]] : i64)
121+
// CHECK: } do {
122+
// CHECK: ^bb0(%[[VAL_13:.*]]: i64):
123+
// CHECK: %[[VAL_14:.*]] = quake.extract_ref %[[VAL_0]]{{\[}}%[[VAL_13]]] : (!quake.veq<2>, i64) -> !quake.ref
124+
// CHECK: %[[VAL_15:.*]] = quake.mz %[[VAL_14]] : (!quake.ref) -> !quake.measure
125+
// CHECK: %[[VAL_16:.*]] = quake.discriminate %[[VAL_15]] : (!quake.measure) -> i3
126+
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_13]], %[[VAL_5]] : i64
127+
// CHECK: %[[VAL_18:.*]] = cc.compute_ptr %[[VAL_4]]{{\[}}%[[VAL_17]]] : (!cc.ptr<!cc.array<i8 x ?>>, i64) -> !cc.ptr<i8>
128+
// CHECK: %[[VAL_19:.*]] = cc.cast unsigned %[[VAL_16]] : (i3) -> i8
129+
// CHECK: cc.store %[[VAL_19]], %[[VAL_18]] : !cc.ptr<i8>
130+
// CHECK: cc.continue %[[VAL_13]] : i64
131+
// CHECK: } step {
132+
// CHECK: ^bb0(%[[VAL_20:.*]]: i64):
133+
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_9]] : i64
134+
// CHECK: cc.continue %[[VAL_21]] : i64
135+
// CHECK: } {invariant}
136+
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_5]], %[[VAL_7]] : i64
137+
// CHECK: %[[VAL_23:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<!cc.array<i3 x ?>>
138+
// CHECK: %[[VAL_24:.*]] = cc.stdvec_init %[[VAL_23]], %[[VAL_3]] : (!cc.ptr<!cc.array<i3 x ?>>, i64) -> !cc.stdvec<i3>
139+
// CHECK: return %[[VAL_24]] : !cc.stdvec<i3>
140+
// CHECK: }
141+

0 commit comments

Comments
 (0)