@@ -51,13 +51,22 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
5151 rewriter.template create <arith::AddIOp>(loc, totalToRead, vecSz);
5252 }
5353
54- // 2. Create the buffer.
55- auto i1Ty = rewriter.getI1Type ();
54+ // 2. Determine the element type from users (default to `i1`).
55+ Type elemTy = rewriter.getI1Type ();
56+ for (auto *user : measureOp.getMeasOut ().getUsers ()) {
57+ if (auto disc = dyn_cast<quake::DiscriminateOp>(user))
58+ if (auto svTy =
59+ dyn_cast<cudaq::cc::StdvecType>(disc.getResult ().getType ())) {
60+ elemTy = svTy.getElementType ();
61+ break ;
62+ }
63+ }
64+ // 3. Create the buffer.
5665 auto i8Ty = rewriter.getI8Type ();
5766 Value buff =
5867 rewriter.template create <cudaq::cc::AllocaOp>(loc, i8Ty, totalToRead);
5968
60- // 3 . Measure each individual qubit and insert the result, in order, into
69+ // 4 . Measure each individual qubit and insert the result, in order, into
6170 // the buffer. For registers/vectors, loop over the entire set of qubits.
6271 Value buffOff = rewriter.template create <arith::ConstantIntOp>(loc, 0 , 64 );
6372 Value one = rewriter.template create <arith::ConstantIntOp>(loc, 1 , 64 );
@@ -66,7 +75,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
6675 if (isa<quake::RefType>(v.getType ())) {
6776 auto meas = rewriter.template create <A>(loc, measTy, v).getMeasOut ();
6877 auto bit =
69- rewriter.template create <quake::DiscriminateOp>(loc, i1Ty , meas);
78+ rewriter.template create <quake::DiscriminateOp>(loc, elemTy , meas);
7079 Value addr = rewriter.template create <cudaq::cc::ComputePtrOp>(
7180 loc, cudaq::cc::PointerType::get (i8Ty), buff, buffOff);
7281 auto bitByte = rewriter.template create <cudaq::cc::CastOp>(
@@ -84,7 +93,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
8493 builder.template create <quake::ExtractRefOp>(loc, v, iv);
8594 auto meas = builder.template create <A>(loc, measTy, qv);
8695 auto bit = builder.template create <quake::DiscriminateOp>(
87- loc, i1Ty , meas.getMeasOut ());
96+ loc, elemTy , meas.getMeasOut ());
8897 if (auto registerName = measureOp.getRegisterNameAttr ())
8998 meas.setRegisterName (registerName);
9099 Value offset =
@@ -99,15 +108,15 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
99108 }
100109 }
101110
102- // 4 . Use the buffer as an initialization expression and create the
103- // std::vec<bool > value.
104- auto stdvecTy = cudaq::cc::StdvecType::get (rewriter.getContext (), i1Ty );
111+ // 5 . Use the buffer as an initialization expression and create the
112+ // std::vec<iN > value.
113+ auto stdvecTy = cudaq::cc::StdvecType::get (rewriter.getContext (), elemTy );
105114 for (auto *out : measureOp.getMeasOut ().getUsers ())
106115 if (auto disc = dyn_cast_if_present<quake::DiscriminateOp>(out)) {
107- auto ptrArrI1Ty =
108- cudaq::cc::PointerType::get (cudaq::cc::ArrayType::get (i1Ty ));
116+ auto ptrArrTy =
117+ cudaq::cc::PointerType::get (cudaq::cc::ArrayType::get (elemTy ));
109118 auto buffCast =
110- rewriter.template create <cudaq::cc::CastOp>(loc, ptrArrI1Ty , buff);
119+ rewriter.template create <cudaq::cc::CastOp>(loc, ptrArrTy , buff);
111120 rewriter.template replaceOpWithNewOp <cudaq::cc::StdvecInitOp>(
112121 disc, stdvecTy, buffCast, totalToRead);
113122 }
@@ -119,7 +128,7 @@ class ExpandRewritePattern : public OpRewritePattern<A> {
119128
120129// Expand a `quake.discriminate` on a `!quake.measurements<N>` value into
121130// individual `get_measure` + `discriminate` operations, producing a
122- // `!cc.stdvec<i1>` .
131+ // `!cc.stdvec<iK>` where K is inferred from the original result type .
123132class ExpandDiscriminatePattern
124133 : public OpRewritePattern<quake::DiscriminateOp> {
125134public:
@@ -133,7 +142,9 @@ class ExpandDiscriminatePattern
133142 return failure ();
134143
135144 auto loc = discOp.getLoc ();
136- auto i1Ty = rewriter.getI1Type ();
145+ auto stdvecResTy =
146+ cast<cudaq::cc::StdvecType>(discOp.getResult ().getType ());
147+ auto elemTy = stdvecResTy.getElementType ();
137148 auto i8Ty = rewriter.getI8Type ();
138149 Value totalToRead;
139150
@@ -150,7 +161,8 @@ class ExpandDiscriminatePattern
150161 std::size_t n = measTy.getSize ();
151162 for (std::size_t i = 0 ; i < n; ++i) {
152163 Value getMeas = rewriter.create <quake::GetMeasureOp>(loc, measVal, i);
153- Value bit = rewriter.create <quake::DiscriminateOp>(loc, i1Ty, getMeas);
164+ Value bit =
165+ rewriter.create <quake::DiscriminateOp>(loc, elemTy, getMeas);
154166 Value idx = rewriter.create <arith::ConstantIntOp>(loc, i, 64 );
155167 Value addr = rewriter.create <cudaq::cc::ComputePtrOp>(
156168 loc, cudaq::cc::PointerType::get (i8Ty), buff, idx);
@@ -160,12 +172,12 @@ class ExpandDiscriminatePattern
160172 }
161173 }
162174
163- auto stdvecTy = cudaq::cc::StdvecType::get (rewriter. getContext (), i1Ty);
164- auto ptrArrI1Ty =
165- cudaq::cc::PointerType::get ( cudaq::cc::ArrayType::get (i1Ty));
166- auto buffCast = rewriter.create <cudaq::cc::CastOp>(loc, ptrArrI1Ty , buff);
167- rewriter.replaceOpWithNewOp <cudaq::cc::StdvecInitOp>(discOp, stdvecTy,
168- buffCast, totalToRead);
175+ auto ptrArrElemTy =
176+ cudaq::cc::PointerType::get ( cudaq::cc::ArrayType::get (elemTy));
177+ auto buffCast =
178+ rewriter.create <cudaq::cc::CastOp>(loc, ptrArrElemTy , buff);
179+ rewriter.replaceOpWithNewOp <cudaq::cc::StdvecInitOp>(
180+ discOp, stdvecResTy, buffCast, totalToRead);
169181 return success ();
170182 }
171183};
0 commit comments