We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 660ec3d commit 13e8181Copy full SHA for 13e8181
torch/fx/experimental/symbolic_shapes.py
@@ -2930,8 +2930,11 @@ def is_dim(src):
2930
def get_expression(tensor_dim_src):
2931
fake = placeholders[source_index[tensor_dim_src.base.name()]]
2932
symint = fake.shape[tensor_dim_src.idx]
2933
- assert isinstance(symint, torch.SymInt)
2934
- return symint.node.expr
+ if isinstance(symint, torch.SymInt):
+ return symint.node.expr
2935
+ else:
2936
+ assert type(symint) is int, f"Expected int, got {type(symint)}"
2937
+ return symint
2938
2939
for src1, src2 in equalities_inputs.source_pairs:
2940
expr1, expr2 = get_expression(src1), get_expression(src2)
0 commit comments