Recognize dot from naive sum of broadcasted muls #864

Open
@ricardoV94

Description

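Brought up in #858. A matrix product written as a sum over a broadcasted multiplication is not rewritten to a dot, and in the timings below it runs roughly 80 times slower than `a @ b`: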

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a", shape=(200, 300))
b = pt.matrix("b", shape=(300, 400))
dot = (a[:, :, None] * b).sum(1)

fn = pytensor.function([a, b], dot)
pytensor.dprint(fn, print_type=True)
# Sum{axis=1} [id A] <Matrix(float64, shape=(200, 400))> 3
#  └─ Mul [id B] <Tensor3(float64, shape=(200, 300, 400))> 2
#     ├─ ExpandDims{axis=2} [id C] <Tensor3(float64, shape=(200, 300, 1))> 1
#     │  └─ a [id D] <Matrix(float64, shape=(200, 300))>
#     └─ ExpandDims{axis=0} [id E] <Tensor3(float64, shape=(1, 300, 400))> 0
#        └─ b [id F] <Matrix(float64, shape=(300, 400))>

fn_dot = pytensor.function([a, b], a @ b)
print(); pytensor.dprint(fn_dot, print_type=True)
# Dot22 [id A] <Matrix(float64, shape=(200, 400))> 0
#  ├─ a [id B] <Matrix(float64, shape=(200, 300))>
#  └─ b [id C] <Matrix(float64, shape=(300, 400))>

a_test = np.random.normal(size=a.type.shape)
b_test = np.random.normal(size=b.type.shape)
np.testing.assert_allclose(fn(a_test, b_test), fn_dot(a_test, b_test))

%timeit fn(a_test, b_test)  # 70.9 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit fn_dot(a_test, b_test)  # 861 µs ± 148 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
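
For reference, the equivalence that such a rewrite would rely on can be checked in plain NumPy. This is only a sketch, not a proposed implementation: the shapes are smaller than in the example above and the variable names are illustrative. It shows that the naive form materializes a 3-D intermediate before reducing, while the matrix product does not. With the shapes from the issue, that intermediate has 200 * 300 * 400 float64 entries (about 192 MB), which probably accounts for much of the slowdown.

```python
import numpy as np

rng = np.random.default_rng(0)
a_val = rng.normal(size=(20, 30))
b_val = rng.normal(size=(30, 40))

# Naive form: broadcast to a (20, 30, 40) intermediate, then reduce over axis 1.
intermediate = a_val[:, :, None] * b_val
naive = intermediate.sum(axis=1)

# Direct matrix product: same result, no 3-D intermediate.
direct = a_val @ b_val

np.testing.assert_allclose(naive, direct)
print(intermediate.shape, naive.shape)
```

The pattern to recognize is therefore a `Sum` over the single axis that both `Mul` inputs share, where every other axis is broadcast in one of the two inputs.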
