From f86a41ffb8cd438c7e22f04b74d178042c954ef6 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 22 Jul 2024 08:51:15 +0000 Subject: [PATCH] Rewrite specifically for Sum and Prod to remove Join --- pytensor/tensor/rewriting/math.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 75dba82d97..0129740901 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -86,6 +86,7 @@ from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import pow as pt_pow from pytensor.tensor.math import sum as pt_sum +from pytensor.tensor.math import prod as pt_prod from pytensor.tensor.rewriting.basic import ( alloc_like, broadcasted_by, @@ -1754,6 +1755,30 @@ def local_reduce_broadcastable(fgraph, node): # -- in this case we can remove the reduction completely return [new_reduced.astype(odtype)] +@register_canonicalize +@register_uncanonicalize +@register_specialize +@node_rewriter([Sum, Prod]) +def local_useless_join_(fgraph, node): + """ + sum(join(tensor1, tensor2...)) => sum(sum(tensor) for tensor in tensors) + or + prod(join(tensor1, tensor2...)) => prod(prod(tensor) for tensor in tensors) + + """ + (node_inps,) = node.inputs + if node_inps.owner and isinstance(node_inps.owner.op, Join): + inpts = node_inps.owner.inputs[1:] + # This specific implementation would introduce a + # `MakeVector` into the graph, which would then + # be rewritten again with + # pytensor/tensor/rewriting/basic.py:local_sum_make_vector + # A similar rewrite must be created for `prod` + if isinstance(node.op, Sum): + return [pt_sum([pt_sum(inp) for inp in inpts])] + elif isinstance(node.op, Prod): + return [pt_prod([pt_prod(inp) for inp in inpts])] + @register_specialize @node_rewriter([Sum, Prod])