Skip to content

Commit 7cb078d

Browse files
committed
Tag rewrites that make shape assumptions
1 parent 48d90b2 commit 7cb078d

File tree

4 files changed

+27
-31
lines changed

4 files changed

+27
-31
lines changed

pytensor/configdefaults.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -682,25 +682,7 @@ def add_traceback_configvars():
682682

683683

684684
def add_experimental_configvars():
685-
config.add(
686-
"experimental__local_alloc_elemwise",
687-
"DEPRECATED: If True, enable the experimental"
688-
" optimization local_alloc_elemwise."
689-
" Generates error if not True. Use"
690-
" optimizer_excluding=local_alloc_elemwise"
691-
" to disable.",
692-
BoolParam(True),
693-
in_c_key=False,
694-
)
695-
696-
# False could make the graph faster but not as safe.
697-
config.add(
698-
"experimental__local_alloc_elemwise_assert",
699-
"When the local_alloc_elemwise is applied, add"
700-
" an assert to highlight shape errors.",
701-
BoolParam(True),
702-
in_c_key=False,
703-
)
685+
return
704686

705687

706688
def add_error_and_warning_configvars():

pytensor/tensor/rewriting/basic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def local_scalar_tensor_scalar(fgraph, node):
257257
return [s]
258258

259259

260-
@register_specialize("local_alloc_elemwise")
260+
@register_specialize("shape_unsafe")
261261
@node_rewriter([Elemwise])
262262
def local_elemwise_alloc(fgraph, node):
263263
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
@@ -378,7 +378,7 @@ def dimshuffled_alloc(i):
378378
return ret
379379

380380

381-
@register_canonicalize
381+
@register_canonicalize("shape_unsafe")
382382
@node_rewriter([Elemwise])
383383
def local_fill_sink(fgraph, node):
384384
"""
@@ -429,8 +429,8 @@ def local_fill_sink(fgraph, node):
429429
return replacements
430430

431431

432-
@register_specialize
433-
@register_stabilize
432+
@register_specialize("shape_unsafe")
433+
@register_stabilize("shape_unsafe")
434434
@node_rewriter([fill])
435435
def local_fill_to_alloc(fgraph, node):
436436
r"""Remove `fill`\s or replace them with `Alloc`\s.
@@ -480,8 +480,8 @@ def local_fill_to_alloc(fgraph, node):
480480
)
481481

482482

483-
@register_canonicalize("fast_compile")
484-
@register_useless
483+
@register_canonicalize("fast_compile", "shape_unsafe")
484+
@register_useless("shape_unsafe")
485485
@node_rewriter([fill])
486486
def local_useless_fill(fgraph, node):
487487
"""fill(s,v) -> v
@@ -501,10 +501,10 @@ def local_useless_fill(fgraph, node):
501501
return [v]
502502

503503

504-
@register_specialize
505-
@register_stabilize
506-
@register_canonicalize
507-
@register_useless
504+
@register_specialize("shape_unsafe")
505+
@register_stabilize("shape_unsafe")
506+
@register_canonicalize("shape_unsafe")
507+
@register_useless("shape_unsafe")
508508
@node_rewriter([Alloc])
509509
def local_useless_alloc(fgraph, node):
510510
"""

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
11761176
local_mul_canonizer = AlgebraicCanonizer(
11771177
mul, true_div, reciprocal, mul_calculate, False
11781178
)
1179-
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
1179+
register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canonizer")
11801180

11811181

11821182
@register_canonicalize
@@ -2494,7 +2494,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
24942494
)
24952495

24962496

2497-
register_canonicalize(local_add_canonizer, name="local_add_canonizer")
2497+
register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
24982498

24992499

25002500
def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0):

tests/tensor/rewriting/test_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,3 +1933,17 @@ def test_misc(self):
19331933
x_val = np.random.random((1, 5)).astype(self.dtype)
19341934
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
19351935
assert np.array_equal(func(y_val, x_val), exp_res)
1936+
1937+
1938+
def test_shape_unsafe_tag():
1939+
mode = get_mode("FAST_RUN")
1940+
x = vector("x")
1941+
y = vector("y")
1942+
out = x * y / y
1943+
1944+
fn = function([x, y], out, mode=mode)
1945+
np.testing.assert_equal(fn([0, 1], [2, 3, 4]), [0, 1])
1946+
1947+
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
1948+
with pytest.raises(ValueError):
1949+
fn([0, 1], [2, 3, 4]), [0, 1]

0 commit comments

Comments
 (0)