Skip to content

Commit c79ab47

Browse files
paul0403tzunghanjuangritu-thombre99mehrdad2mdime10
authored
Update llvm version (#1752)
**Context:** We update the llvm version tagged by jax 0.6.0: ``` mhlo=617a9361d186199480c080c9e8c474a5e30c22d1 llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158 ``` We also update Enzyme to the latest version, which is 0.0.180, at commit `db0181320d6e425ee963bd496ed0d8dbb615be18` **Description of the Change:** Firstly, jax recently moved from the Google github organization to its own jax-ml organization. This means the urls, and the retrieval method for the underlying llvm and mhlo git commit tags, needs to be updated. (Thanks @mehrdad2m !) Now on to the actual changes. I will list the changes in increasing complexity. 1. The new enzyme cmake target is `EnzymeStatic-21` (from 20) 2. Enzyme works with a later llvm then our target, so it has some llvm intrinsics unknown to the one we are targeting. We patch them away. They do not concern us since they are all intrinsics for nvidia backends. 3. `applyPatternsAndFoldGreedily` is removed. Drop-in replacement is `applyPatternsGreedily`. llvm/llvm-project#104649, llvm/llvm-project#126701 4. ops with `CallOpInterface` must have two new optional attributes `arg_attrs` and `res_attrs` llvm/llvm-project#123176 5. `CallInterfaceCallable` objects now must be directly casted to the callee `SymbolRefAttr`, i.e. `callee.get<SymbolRefAttr>()` -> `cast<SymbolRefAttr>(callee)` llvm/llvm-project@35e8989 6. The `lookupOrCreateFn` family of functions now return `FailureOr<funcop>` instead of just `funcop`, so a `.value()` needs to be used to retrieve the underlying `funcop`. llvm/llvm-project@e84f6b6 7. The cpp api for `OneShotBufferizePassOptions` no longer needs complicated lambdas for the type converter options. They can be set with the `mlir::bufferization::LayoutMapOption::IdentityLayoutMap` options directly. 8. The individual `match` and `rewrite` methods in pattern rewrites are removed. Use the two-in-one `matchAndRewrite` instead. llvm/llvm-project#129861 9. For rewrite patterns with 1-to-N convertions, a new `macthAndRewrite` overload with `OneToNOpAdaptor` must be used. For us, this is only the `catalyst.list*` ops. llvm/llvm-project#116470 10. The lowering of `cf::AssertOp` to llvm was split from the overall`--covert-cf-to-llvm` pass. We need to manually call this separate pattern for cf.assert duriing quantum to llvm dialect lowering, where we also convert cf to llvm. https://github.com/llvm/llvm-project/pull/120431/files 11. The new mhlo depends on a [shardy](https://github.com/openxla/shardy) dialect. Shardy is built with bazel, not cmake. Building shardy ourselves would be very difficult (not having bazel in our build ecosystem is a hard constraint, cc @mlxd ), and also not necessary (we just use mhlo for their "standard" passes). We thus patch out all shardy components. 12. Three necessary passes were removed in mhlo: `mhlo-legalize-control-flow`, `mhlo-legalize-to-std`, `hlo-legalize-sort` tensorflow/mlir-hlo@4a640be#diff-ef0d7e30da19a396ba036405a9ef636f8b1be194618b0a90f4602671fc2ef34d tensorflow/mlir-hlo@2a5e267#diff-f8c7cb07b43593403e00e0dbf9983f0186b4eb70368cc99af3b924061f1ea46f - Alongside the removal of `mhlo-legalize-to-std`, the cmake target `MhloToStandard` was removed too. We simply patch them back for now. **For the above two points, note that there will be an overall migration to the stablehlo repo, as mhlo is sunseting. Therefore, spending too much time on this isn't necessary, so we just patch.** 13. The new pattern applicator (`applyPatternsGreedily`) is more aggressive in dead code elimination, and is eliminating dead `Value`s in the adjoint gradient method. The `nodealloc` function we generate for adjoint gradient lowering used to only return the qreg, not the expval result. This causes the expval op to be eliminated since it has no users. This further causes wrong gradient results, since the entire program, all ops included (regardless of dead or not), impacts the gradient through chain rule. To avoid this, we return the expval result as well. In doing this, we implicitly assume that differentiated qnodes can only return expval. Although this assumption is true and also restricted by frontend, ideally we should not have it hard coded. We leave this as a TODO for a future feature. 14. The old `--buffer-deallocation` pass is removed. Intended replacement is `--buffer-deallocation-pipeline`. This migration is very complicated. We simply add back the old buffer deallocation pass in the catalyst dialect as a util for now. We will revisit this in #1778 . mlir lit test updates: 1. `bufferization.to_tensor/memref` updated assembly format 2. gradient adjoint lowering test returns both qreg and expval 3. Some inverse unrealized conversion cast pairs are canceled by the new pattern rewriter. 4. `llvm.mlir.undef` is deprecated, use `llvm.mlir.poison` instead. llvm/llvm-project#125629 **Benefits:** Up to date with upstream versions. [sc-92017] --------- Co-authored-by: Tzung-Han Juang <[email protected]> Co-authored-by: Ritu Thombre <[email protected]> Co-authored-by: Mehrdad Malekmohammadi <[email protected]> Co-authored-by: Mehrdad Malek <[email protected]> Co-authored-by: David Ittah <[email protected]> Co-authored-by: Joey Carter <[email protected]>
1 parent 4228bf1 commit c79ab47

File tree

89 files changed

+2898
-453
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+2898
-453
lines changed

.dep-versions

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
# Always update the version check in catalyst.__init__ when changing the JAX version.
2-
3-
#############
4-
# We track mlir submodule versions from jax 0.4.32 for now
5-
# These are the earliest versions with complete upstream bufferization changes
6-
# Versions are retrieved from
7-
# python3 .github/workflows/set_dep_versions.py 0.4.32
8-
#############
9-
2+
# To update JAX version alongside compatible dependency tags, run the following script:
3+
# python3 .github/workflows/set_dep_versions.py {JAX_version}
104
jax=0.6.0
11-
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
12-
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
13-
enzyme=v0.0.149
5+
mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
6+
llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
7+
enzyme=v0.0.180
148

159
# Always remove custom PL/LQ versions before release.
1610

.github/workflows/build-wheel-linux-arm64.yaml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@ jobs:
187187
run: |
188188
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
189189
190-
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
191-
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
192-
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
190+
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
191+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
192+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
193+
popd
193194
194195
cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
195196
-DCMAKE_BUILD_TYPE=Release \
@@ -215,14 +216,19 @@ jobs:
215216
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
216217
run: |
217218
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
219+
220+
export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
221+
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
222+
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
223+
218224
cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
219225
-DCMAKE_BUILD_TYPE=Release \
220226
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
221227
-DENZYME_STATIC_LIB=ON \
222228
-DCMAKE_CXX_VISIBILITY_PRESET=default \
223229
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"
224230
225-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
231+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
226232
227233
- name: Save Enzyme Build
228234
id: save-enzyme-build

.github/workflows/build-wheel-linux-x86_64.yaml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,10 @@ jobs:
210210
run: |
211211
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
212212
213-
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
214-
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
215-
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
213+
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
214+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
215+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
216+
popd
216217
217218
cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
218219
-DCMAKE_BUILD_TYPE=Release \
@@ -238,14 +239,19 @@ jobs:
238239
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
239240
run: |
240241
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
242+
243+
export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
244+
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
245+
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
246+
241247
cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
242248
-DCMAKE_BUILD_TYPE=Release \
243249
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
244250
-DENZYME_STATIC_LIB=ON \
245251
-DCMAKE_CXX_VISIBILITY_PRESET=default \
246252
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"
247253
248-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
254+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
249255
250256
- name: Save Enzyme Build
251257
id: save-enzyme-build

.github/workflows/build-wheel-macos-arm64.yaml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ jobs:
185185
run: |
186186
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH
187187
188-
export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
189-
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
190-
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
188+
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
189+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
190+
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
191+
popd
191192
192193
cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
193194
-DCMAKE_BUILD_TYPE=Release \
@@ -212,13 +213,17 @@ jobs:
212213
- name: Build Enzyme
213214
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
214215
run: |
216+
export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
217+
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
218+
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
219+
215220
cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
216221
-DCMAKE_BUILD_TYPE=Release \
217222
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
218223
-DENZYME_STATIC_LIB=ON \
219224
-DCMAKE_CXX_VISIBILITY_PRESET=default
220225
221-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
226+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
222227
223228
- name: Save Enzyme Build
224229
id: save-enzyme-build

.github/workflows/check-catalyst.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ jobs:
146146
sudo apt-get update
147147
sudo apt-get install -y python3 python3-pip cmake ninja-build clang lld
148148
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
149-
python3 -m pip install numpy pybind11
149+
python3 -m pip install numpy pybind11 nanobind
150150
151151
- name: Build LLVM
152152
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
@@ -194,7 +194,7 @@ jobs:
194194
uses: actions/cache@v4
195195
with:
196196
path: mhlo-build
197-
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}
197+
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0
198198

199199
- name: Get Cached LLVM Source
200200
id: cache-llvm-source
@@ -351,7 +351,7 @@ jobs:
351351
uses: actions/cache/restore@v4
352352
with:
353353
path: mhlo-build
354-
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}
354+
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0
355355
fail-on-cache-miss: true
356356

357357
- name: Get Cached Enzyme Source

.github/workflows/set_dep_versions.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
assert os.path.isfile(dep_versions_path)
3333
assert os.path.isfile(catalyst_init_path)
3434

35-
url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/WORKSPACE"
35+
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE"
3636
response = requests.get(url)
3737
match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text)
3838
if not match:
39-
url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/third_party/xla/workspace.bzl"
39+
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl"
4040
response = requests.get(url)
4141
match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text)
4242
xla_commit = match.group(1)
@@ -67,21 +67,16 @@
6767
response = requests.get(url).json()
6868
hlo_commit = response["items"][0]["sha"]
6969

70-
existing_text = open(dep_versions_path, "r", encoding="UTF-8").read()
71-
match = re.search(r"enzyme=([a-zA-Z0-9]*)", existing_text)
72-
enzyme_commit = match.group(1)
73-
74-
with open(dep_versions_path, "w", encoding="UTF-8") as f:
75-
f.write(
76-
f"""\
77-
jax={jax_version}
78-
mhlo={hlo_commit}
79-
llvm={llvm_commit}
80-
enzyme={enzyme_commit}
81-
"""
82-
)
83-
8470
quote = '"'
85-
cmd = f"sed -i 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}"
86-
res = os.system(cmd)
87-
assert res == 0
71+
# Update each version using sed
72+
cmds = [
73+
f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}",
74+
f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}",
75+
f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}",
76+
# Update jaxlib version in __init__.py
77+
rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}",
78+
]
79+
80+
for cmd in cmds:
81+
res = os.system(cmd)
82+
assert res == 0

doc/dev/transforms.rst

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -252,23 +252,22 @@ Note how the value ``%q2`` links the two operations together from definition ``(
252252
across several other instructions.
253253

254254
As seen in the `pattern rewriter documentation <https://mlir.llvm.org/docs/PatternRewriter/#defining-patterns>`_,
255-
a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the ``match``
256-
and ``rewrite`` methods (refer to the link for the full class and up to date information):
255+
a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the
256+
``matchAndRewrite`` method (refer to the link for the full class and up to date information):
257257

258258
.. code-block:: cpp
259259
260260
struct QubitUnitaryFusion : public OpRewritePattern<QubitUnitaryOp>
261261
{
262262
...
263263
264-
LogicalResult match(QubitUnitaryOp op) const override {
265-
// The ``match`` method returns ``success()`` if the pattern is a match, failure
266-
// otherwise.
267-
}
268-
269-
void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) {
270-
// The ``rewrite`` method performs mutations on the IR rooted at ``op`` using
271-
// the provided rewriter. All mutations must go through the provided rewriter.
264+
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override {
265+
// The `matchAndRewrite` method performs both the pattern matching and the mutation
266+
// on the IR rooted at `op` using the provided rewriter.
267+
// All mutations must go through the provided rewriter and IR mutation should only
268+
// take place after the match is deemed successful.
269+
// matchAndRewrite must return "success" if and only if the IR was modified.
270+
// The root operation is required to either be: updated in-place, replaced, or erased.
272271
}
273272
274273
...
@@ -286,11 +285,11 @@ the second is a list of qubits):
286285
287286
QubitUnitary(*, QubitUnitary(*, *))
288287
289-
Let's implement it in C++:
288+
Let's add the pattern-matching logic to the ``matchAndRewrite`` method:
290289

291290
.. code-block:: cpp
292291
293-
LogicalResult match(QubitUnitaryOp op) const override
292+
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
294293
{
295294
ValueRange qbs = op.getInQubits();
296295
Operation *parent = qbs[0].getDefiningOp();
@@ -314,6 +313,9 @@ Let's implement it in C++:
314313
return failure();
315314
}
316315
316+
// Rewrite logic
317+
// ... We have matched the pattern, now rewrite the IR here
318+
317319
return success();
318320
}
319321
@@ -351,8 +353,8 @@ MLIR will automatically generate canonical ``get*`` methods for attributes like
351353
``out_qubits``, and ``matrix``. When in doubt it's best to have a look at the generated C++ files in
352354
the build folder, named ``QuantumOps.h.inc`` and ``QuantumOps.cpp.inc`` in this instance.
353355

354-
Alright, now that we have the matching part, let's implement the actual transformation via the
355-
``rewrite`` method. All we need to do is replace the original pattern with the following:
356+
Alright, now that we have the matching part, let's add the actual transformation to the
357+
``matchAndRewrite`` method. All we need to do is replace the original pattern with the following:
356358

357359
.. code-block::
358360
@@ -362,8 +364,13 @@ In C++ it will look as follows:
362364

363365
.. code-block:: cpp
364366
365-
void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
367+
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
366368
{
369+
370+
// Pattern matching logic
371+
// ... match the pattern
372+
373+
// Rewrite logic
367374
ValueRange qbs = op.getInQubits();
368375
QubitUnitaryOp parentOp = cast<QubitUnitaryOp>(qbs[0].getDefiningOp());
369376
@@ -410,11 +417,13 @@ In C++ it will look as follows:
410417
// The second unitary is not needed anymore
411418
// Whoever uses the second unitary, use the first one instead!
412419
op.replaceAllUsesWith(parentOp);
420+
421+
return success();
413422
}
414423
415424
When writing transformations, the rewriter is the most important tool we have. It can create new
416425
operations for us, delete others, or change the place in the IR where we are choosing to make
417-
changes (also called the insertion point). Let's have look at some of these elements:
426+
changes (also called the insertion point). Let's have a look at some of these elements:
418427

419428
- **Constructing new operations**:
420429

@@ -512,15 +521,15 @@ and other function operations, which themselves can contain other operations, an
512521
quantumPatterns.add<QubitUnitaryFusion>(ctx);
513522
514523
// Apply patterns in an iterative and greedy manner.
515-
if (failed(applyPatternsAndFoldGreedily(op, std::move(quantumPatterns)))) {
524+
if (failed(applyPatternsGreedily(op, std::move(quantumPatterns)))) {
516525
return signalPassFailure();
517526
}
518527
}
519528
};
520529
521530
To apply patterns we need a `pattern applicator <https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers>`_.
522531
There a few in MLIR but typically you can just use the greedy pattern rewrite driver
523-
(``applyPatternsAndFoldGreedily``), which will iterative over the IR and apply patterns until a
532+
(``applyPatternsGreedily``), which will iterative over the IR and apply patterns until a
524533
fixed point is reached.
525534

526535
.. note::
@@ -565,21 +574,30 @@ gradient ops that specify the finite-difference method, indicated via the ``"fd"
565574

566575
.. code-block:: cpp
567576
568-
LogicalResult FiniteDiffLowering::match(GradOp op)
577+
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
569578
{
570-
if (op.getMethod() == "fd")
571-
return success();
579+
// Pattern matching logic
580+
if (op.getMethod() != "fd")
581+
return failure();
572582
573-
return failure();
583+
// Rewrite logic
584+
// ...
585+
586+
return success();
574587
}
575588
576589
For the rewriting part we'll want to introduce a few new elements, such as looking up symbols
577590
(function names), creating new functions, and changing the insertion point.
578591

579592
.. code-block:: cpp
580593
581-
void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter)
594+
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
582595
{
596+
// Pattern matching logic
597+
if (op.getMethod() != "fd")
598+
return failure();
599+
600+
// Rewrite logic
583601
// First let's find the function the grad operation is referencing.
584602
func::FuncOp callee =
585603
SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, op.getCalleeAttr());
@@ -609,6 +627,8 @@ For the rewriting part we'll want to introduce a few new elements, such as looki
609627
// Populate the function body.
610628
populateFiniteDiffMethod(rewriter, op, gradFn);
611629
}
630+
631+
return success();
612632
}
613633
614634
Symbols are string references to IR objects, which rather than containing a physical reference or
@@ -711,18 +731,20 @@ Alright, our function should now look something like this:
711731
func.return %dx, %dy, %dz : f64, f64, f64
712732
}
713733
714-
Finally, we have to amend our rewrite function to invoke the new function we created and delete the
734+
Finally, we have to amend our ``matchAndRewrite`` function to invoke the new function we created and delete the
715735
``GradOp`` from the IR:
716736

717737
.. code-block:: cpp
718738
719-
void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter)
739+
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
720740
{
721741
...
722742
populateFiniteDiffMethod(rewriter, op, gradFn);
723743
}
724744
725745
rewriter.replaceOpWithNewOp<func::CallOp>(op, gradFn, op.getArgOperands());
746+
747+
return success();
726748
}
727749
728750
Note how we can create a new operation, take its results, and use those to replace another operation

0 commit comments

Comments
 (0)