-
Notifications
You must be signed in to change notification settings - Fork 617
[TorchToLinalg] Support lowering AtenReplicationPad3d to linalg #4233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
Show resolved
Hide resolved
Could you please review this @zjgarvey ? |
Hi @zjgarvey, can you please review this ? |
Hey, @vinitdeodhar . Would you mind addressing the ci failure first? Let me know if you have trouble debugging and I can help you out. |
Hi @zjgarvey the ci failures do not seem to be related to the change and affect other PRs submitted at the time too. I dont have access rights to rerun the job and try again. Here is the error thrown: |
Can you sync the branch with main so we can re-run? |
Thanks ! I synced the branch that it resolved the failures |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the reminder ping. This mostly looks good to me, just some nit comments.
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 | ||
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 | ||
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3 | ||
// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT3]], %[[INT1]], %[[INT0]], %[[INT3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you should check these (they just end up getting DCE'd anyway).
SmallVector<Value> slices(tileWidth, slice); | ||
return rewriter.create<tensor::ConcatOp>(loc, dimension, slices); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sure it folds later in the compiler, but would simple enough to add a check for tileWidth == 1
so we don't generate redundant concats (e.g. as seen in the lit test).
Add support of AtenReplicationPad3d in torch dialect and lowering it to linalg backend
AtenReplicationPad3d is lowered using a sequence of tensor.extract_slice and tensor.concat operations consistent with the existing lowerings of AtenReplicationPad1d and AtenReplicationPad2d for the linalg backend