Commit 13949fe

lucianopaz authored and ricardoV94 committed
Implement vectorized adstock transformations
1 parent 3e54549 commit 13949fe

File tree

3 files changed (+146 -66)

pymc_marketing/mmm/delayed_saturated_mmm.py
Lines changed: 3 additions & 5 deletions

```diff
@@ -5,10 +5,7 @@
 
 from pymc_marketing.mmm.base import MMM
 from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
-from pymc_marketing.mmm.transformers import (
-    geometric_adstock_vectorized,
-    logistic_saturation,
-)
+from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
 from pymc_marketing.mmm.validating import ValidateControlColumns
 
 
@@ -80,11 +77,12 @@ def build_model(
 
         channel_adstock = pm.Deterministic(
             name="channel_adstock",
-            var=geometric_adstock_vectorized(
+            var=geometric_adstock(
                 x=channel_data_,
                 alpha=alpha,
                 l_max=adstock_max_lag,
                 normalize=True,
+                axis=0,
             ),
             dims=("date", "channel"),
         )
```

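The practical effect of the new `axis=0` argument is that a single `geometric_adstock` call now adstocks every channel column of a `(date, channel)` matrix at once, with one `alpha` per channel. A minimal sketch of the equivalence, mirroring the vectorized test further below; the matrix values and `l_max=12` are illustrative, not part of the commit:

```python
import numpy as np
import pytensor.tensor as pt

from pymc_marketing.mmm.transformers import geometric_adstock

# Illustrative (date, channel) design matrix: 100 dates, 5 channels.
x = np.random.default_rng(0).uniform(size=(100, 5))
alpha = [0.9, 0.33, 0.5, 0.1, 0.0]  # one decay rate per channel

# One vectorized call: convolve along the date axis (axis=0), broadcasting
# the per-channel alphas across the channel axis.
y = geometric_adstock(
    x=pt.as_tensor_variable(x), alpha=pt.as_tensor_variable(alpha), l_max=12, axis=0
).eval()

# Same result as applying the scalar transform one channel column at a time.
ys = np.stack(
    [geometric_adstock(x=x[:, i], alpha=alpha[i], l_max=12).eval() for i in range(5)],
    axis=1,
)
assert y.shape == x.shape
np.testing.assert_almost_equal(actual=y, desired=ys, decimal=12)
```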
pymc_marketing/mmm/transformers.py
Lines changed: 79 additions & 45 deletions

```diff
@@ -1,7 +1,70 @@
 import pytensor.tensor as pt
+from pytensor.tensor.random.utils import params_broadcast_shapes
 
 
-def geometric_adstock(x, alpha: float = 0.0, l_max: int = 12, normalize: bool = False):
+def batched_convolution(x, w, axis: int = 0):
+    """Apply a 1D convolution in a vectorized way across multiple batch dimensions.
+
+    Parameters
+    ----------
+    x :
+        The array to convolve.
+    w :
+        The weight of the convolution. The last axis of ``w`` determines the number
+        of steps to use in the convolution.
+    axis : int
+        The axis of ``x`` along which to apply the convolution.
+
+    Returns
+    -------
+    y :
+        The result of convolving ``x`` with ``w`` along the desired axis. The shape
+        of the result will match the shape of ``x`` up to broadcasting with ``w``.
+        The convolved axis will show the results of left padding zeros to ``x``
+        while applying the convolutions.
+    """
+    # We move the axis to the last dimension of the array so that it's easier to
+    # reason about parameter broadcasting. We will move the axis back at the end.
+    orig_ndim = x.ndim
+    axis = axis if axis >= 0 else orig_ndim + axis
+    w = pt.as_tensor(w)
+    x = pt.moveaxis(x, axis, -1)
+    l_max = w.type.shape[-1]
+    if l_max is None:
+        try:
+            l_max = w.shape[-1].eval()
+        except Exception:
+            pass
+    # Get the broadcast shapes of x and w but ignoring their last dimension.
+    # The last dimension of x is the "time" axis, which doesn't get broadcast.
+    # The last dimension of w is the number of time steps that go into the convolution.
+    x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])
+    x = pt.broadcast_to(x, x_shape)
+    w = pt.broadcast_to(w, w_shape)
+    x_time = x.shape[-1]
+    shape = (*x.shape, w.shape[-1])
+    # Make a tensor with x at the different time lags needed for the convolution.
+    padded_x = pt.zeros(shape, dtype=x.dtype)
+    if l_max is not None:
+        for i in range(l_max):
+            padded_x = pt.set_subtensor(
+                padded_x[..., i:x_time, i], x[..., : x_time - i]
+            )
+    else:  # pragma: no cover
+        raise NotImplementedError(
+            "At the moment, convolving with weight arrays that don't have a concrete "
+            "shape at compile time is not supported."
+        )
+    # The convolution is treated as an element-wise product that then gets reduced
+    # along the dimension that represents the convolution time lags.
+    conv = pt.sum(padded_x * w[..., None, :], axis=-1)
+    # Move the "time" axis back to where it was in the original x array.
+    return pt.moveaxis(conv, -1, axis + conv.ndim - orig_ndim)
+
+
+def geometric_adstock(
+    x, alpha: float = 0.0, l_max: int = 12, normalize: bool = False, axis: int = 0
+):
     """Geometric adstock transformation.
 
     Adstock with geometric decay assumes advertising effect peaks at the same
@@ -31,29 +94,19 @@ def geometric_adstock(x, alpha: float = 0.0, l_max: int = 12, normalize: bool =
     .. [1] Jin, Yuxue, et al. "Bayesian methods for media mix modeling
        with carryover and shape effects." (2017).
     """
-    cycles = [pt.concatenate([pt.zeros(i), x[: x.shape[0] - i]]) for i in range(l_max)]
-    x_cycle = pt.stack(cycles)
-    w = pt.as_tensor_variable([pt.power(alpha, i) for i in range(l_max)])
-    w = w / pt.sum(w) if normalize else w
-    return pt.dot(w, x_cycle)
-
-
-def geometric_adstock_vectorized(x, alpha, l_max: int = 12, normalize: bool = False):
-    """Vectorized geometric adstock transformation."""
-    cycles = [
-        pt.concatenate(tensor_list=[pt.zeros(shape=x.shape)[:i], x[: x.shape[0] - i]])
-        for i in range(l_max)
-    ]
-    x_cycle = pt.stack(cycles)
-    x_cycle = pt.transpose(x=x_cycle, axes=[1, 2, 0])
-    w = pt.as_tensor_variable([pt.power(alpha, i) for i in range(l_max)])
-    w = pt.transpose(w)[None, ...]
-    w = w / pt.sum(w, axis=2, keepdims=True) if normalize else w
-    return pt.sum(pt.mul(x_cycle, w), axis=2)
+    w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype))
+    w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
+    return batched_convolution(x, w, axis=axis)
 
 
 def delayed_adstock(
-    x, alpha: float = 0.0, theta: int = 0, l_max: int = 12, normalize: bool = False
+    x,
+    alpha: float = 0.0,
+    theta: int = 0,
+    l_max: int = 12,
+    normalize: bool = False,
+    axis: int = 0,
 ):
     """Delayed adstock transformation.
 
@@ -83,31 +136,12 @@ def delayed_adstock(
     .. [1] Jin, Yuxue, et al. "Bayesian methods for media mix modeling
        with carryover and shape effects." (2017).
     """
-    cycles = [pt.concatenate([pt.zeros(i), x[: x.shape[0] - i]]) for i in range(l_max)]
-    x_cycle = pt.stack(cycles)
-    w = pt.as_tensor_variable(
-        [pt.power(alpha, ((i - theta) ** 2)) for i in range(l_max)]
-    )
-    w = w / pt.sum(w) if normalize else w
-    return pt.dot(w, x_cycle)
-
-
-def delayed_adstock_vectorized(
-    x, alpha, theta, l_max: int = 12, normalize: bool = False
-):
-    """Delayed adstock transformation."""
-    cycles = [
-        pt.concatenate(tensor_list=[pt.zeros(shape=x.shape)[:i], x[: x.shape[0] - i]])
-        for i in range(l_max)
-    ]
-    x_cycle = pt.stack(cycles)
-    x_cycle = pt.transpose(x=x_cycle, axes=[1, 2, 0])
-    w = pt.as_tensor_variable(
-        [pt.power(alpha, ((i - theta) ** 2)) for i in range(l_max)]
+    w = pt.power(
+        pt.as_tensor(alpha)[..., None],
+        (pt.arange(l_max, dtype=x.dtype) - pt.as_tensor(theta)[..., None]) ** 2,
     )
-    w = pt.transpose(w)[None, ...]
-    w = w / pt.sum(w, axis=2, keepdims=True) if normalize else w
-    return pt.sum(pt.mul(x_cycle, w), axis=2)
+    w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
+    return batched_convolution(x, w, axis=axis)
 
 
 def logistic_saturation(x, lam: float = 0.5):
```

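For intuition, what `batched_convolution` computes along the chosen axis is a causal convolution of `x` with `w`, with `x` left-padded by zeros: `y[t] = sum_i w[i] * x[t - i]`. A plain-NumPy reference for the 1D case, written as a sketch with illustrative arrays and checked against the function above:

```python
import numpy as np
import pytensor.tensor as pt

from pymc_marketing.mmm.transformers import batched_convolution

rng = np.random.default_rng(42)
x = rng.normal(size=7)          # illustrative 1D series
w = np.array([1.0, 0.5, 0.25])  # three convolution steps (l_max = 3)

# NumPy reference: y[t] = sum_i w[i] * x[t - i], treating x[t] as 0 for t < 0,
# i.e. the series is left-padded with zeros.
y_ref = np.array(
    [sum(w[i] * x[t - i] for i in range(len(w)) if t - i >= 0) for t in range(len(x))]
)

y = batched_convolution(pt.as_tensor_variable(x), w, axis=0).eval()
np.testing.assert_allclose(y, y_ref)
```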
tests/mmm/test_transformers.py
Lines changed: 64 additions & 16 deletions

```diff
@@ -1,13 +1,13 @@
 import numpy as np
+import pytensor
 import pytensor.tensor as pt
 import pytest
 from pytensor.tensor.var import TensorVariable
 
 from pymc_marketing.mmm.transformers import (
+    batched_convolution,
     delayed_adstock,
-    delayed_adstock_vectorized,
     geometric_adstock,
-    geometric_adstock_vectorized,
     logistic_saturation,
     tanh_saturation,
 )
@@ -27,6 +27,58 @@ def dummy_design_matrix():
     )
 
 
+@pytest.fixture(
+    scope="module", params=["ndarray", "TensorConstant", "TensorVariable"], ids=str
+)
+def convolution_inputs(request):
+    x_val = np.ones((3, 4, 5))
+    w_val = np.ones((2))
+    if request.param == "ndarray":
+        return x_val, w_val, None, None
+    elif request.param == "TensorConstant":
+        return pt.as_tensor_variable(x_val), pt.as_tensor_variable(w_val), None, None
+    elif request.param == "TensorVariable":
+        return (
+            pt.dtensor3("x"),
+            pt.specify_shape(pt.dvector("w"), w_val.shape),
+            x_val,
+            w_val,
+        )
+
+
+@pytest.fixture(scope="module", params=[0, 1, -1])
+def convolution_axis(request):
+    return request.param
+
+
+def test_batched_convolution(convolution_inputs, convolution_axis):
+    x, w, x_val, w_val = convolution_inputs
+    y = batched_convolution(x, w, convolution_axis)
+    if x_val is None:
+        y_val = y.eval()
+        expected_shape = getattr(x, "value", x).shape
+    else:
+        y_val = pytensor.function([x, w], y)(x_val, w_val)
+        expected_shape = x_val.shape
+    assert y_val.shape == expected_shape
+    y_val = np.moveaxis(y_val, convolution_axis, 0)
+    x_val = np.moveaxis(
+        x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0
+    )
+    assert np.allclose(y_val[0], x_val[0])
+    assert np.allclose(y_val[1:], x_val[1:] + x_val[:-1])
+
+
+def test_batched_convolution_broadcasting():
+    x_val = np.random.default_rng(42).normal(size=(3, 1, 5))
+    x = pt.as_tensor_variable(x_val)
+    w = pt.as_tensor_variable(np.ones((1, 1, 4, 2)))
+    y = batched_convolution(x, w, axis=-1).eval()
+    assert y.shape == (1, 3, 4, 5)
+    assert np.allclose(y[..., 0], x_val[..., 0])
+    assert np.allclose(y[..., 1:], x_val[..., 1:] + x_val[..., :-1])
+
+
 class TestsAdstockTransformers:
     def test_geometric_adstock_x_zero(self):
         x = np.zeros(shape=(100))
@@ -62,14 +114,12 @@ def test_delayed_adstock_x_zero(self):
         y = delayed_adstock(x=x, alpha=0.2, theta=2, l_max=4)
         np.testing.assert_array_equal(x=x, y=y.eval())
 
-    def test_geometric_adstock_vactorized(self, dummy_design_matrix):
+    def test_geometric_adstock_vectorized(self, dummy_design_matrix):
         x = dummy_design_matrix.copy()
         x_tensor = pt.as_tensor_variable(x)
         alpha = [0.9, 0.33, 0.5, 0.1, 0.0]
         alpha_tensor = pt.as_tensor_variable(alpha)
-        y_tensor = geometric_adstock_vectorized(
-            x=x_tensor, alpha=alpha_tensor, l_max=12
-        )
+        y_tensor = geometric_adstock(x=x_tensor, alpha=alpha_tensor, l_max=12, axis=0)
         y = y_tensor.eval()
 
         y_tensors = [
@@ -80,15 +130,15 @@ def test_geometric_adstock_vactorized(self, dummy_design_matrix):
         assert y.shape == x.shape
         np.testing.assert_almost_equal(actual=y, desired=ys, decimal=12)
 
-    def test_delayed_adstock_vactorized(self, dummy_design_matrix):
+    def test_delayed_adstock_vectorized(self, dummy_design_matrix):
         x = dummy_design_matrix
         x_tensor = pt.as_tensor_variable(x)
         alpha = [0.9, 0.33, 0.5, 0.1, 0.0]
         alpha_tensor = pt.as_tensor_variable(alpha)
         theta = [0, 1, 2, 3, 4]
         theta_tensor = pt.as_tensor_variable(theta)
-        y_tensor = delayed_adstock_vectorized(
-            x=x_tensor, alpha=alpha_tensor, theta=theta_tensor, l_max=12
+        y_tensor = delayed_adstock(
+            x=x_tensor, alpha=alpha_tensor, theta=theta_tensor, l_max=12, axis=0
         )
         y = y_tensor.eval()
 
@@ -220,7 +270,7 @@ def test_logistic_saturation_delayed_adstock_composition(
         assert z2_eval.max() <= 1
         assert z2_eval.min() >= 0
 
-    def test_geometric_adstock_vactorized_logistic_saturation(
+    def test_geometric_adstock_vectorized_logistic_saturation(
         self, dummy_design_matrix
     ):
         x = dummy_design_matrix.copy()
@@ -229,9 +279,7 @@ def test_geometric_adstock_vactorized_logistic_saturation(
         alpha_tensor = pt.as_tensor_variable(alpha)
         lam = [0.5, 1.0, 2.0, 3.0, 4.0]
         lam_tensor = pt.as_tensor_variable(lam)
-        y_tensor = geometric_adstock_vectorized(
-            x=x_tensor, alpha=alpha_tensor, l_max=12
-        )
+        y_tensor = geometric_adstock(x=x_tensor, alpha=alpha_tensor, l_max=12, axis=0)
         z_tensor = logistic_saturation(x=y_tensor, lam=lam_tensor)
         z = z_tensor.eval()
 
@@ -246,7 +294,7 @@ def test_geometric_adstock_vactorized_logistic_saturation(
         assert zs.shape == x.shape
         np.testing.assert_almost_equal(actual=z, desired=zs, decimal=12)
 
-    def test_delayed_adstock_vactorized_logistic_saturation(self, dummy_design_matrix):
+    def test_delayed_adstock_vectorized_logistic_saturation(self, dummy_design_matrix):
         x = dummy_design_matrix.copy()
         x_tensor = pt.as_tensor_variable(x)
         alpha = [0.9, 0.33, 0.5, 0.1, 0.0]
@@ -255,8 +303,8 @@ def test_delayed_adstock_vactorized_logistic_saturation(self, dummy_design_matrix):
         theta_tensor = pt.as_tensor_variable(theta)
         lam = [0.5, 1.0, 2.0, 3.0, 4.0]
         lam_tensor = pt.as_tensor_variable(lam)
-        y_tensor = delayed_adstock_vectorized(
-            x=x_tensor, alpha=alpha_tensor, theta=theta_tensor, l_max=12
+        y_tensor = delayed_adstock(
+            x=x_tensor, alpha=alpha_tensor, theta=theta_tensor, l_max=12, axis=0
         )
         z_tensor = logistic_saturation(x=y_tensor, lam=lam_tensor)
         z = z_tensor.eval()
```

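Finally, a note on the rewritten weight construction in `delayed_adstock`: the weights are `w[i] = alpha ** ((i - theta) ** 2)`, so the carryover effect peaks `theta` steps after the spend instead of immediately. A small sketch with illustrative parameter values:

```python
import numpy as np
import pytensor.tensor as pt

from pymc_marketing.mmm.transformers import delayed_adstock

alpha, theta, l_max = 0.5, 2, 6  # illustrative values
w = alpha ** ((np.arange(l_max) - theta) ** 2)
assert np.argmax(w) == theta  # the weight profile peaks at lag theta

# A unit spend at t=0 therefore spreads out with exactly that delayed profile.
x = np.zeros(10)
x[0] = 1.0
y = delayed_adstock(
    x=pt.as_tensor_variable(x), alpha=alpha, theta=theta, l_max=l_max
).eval()
np.testing.assert_allclose(y[:l_max], w)
np.testing.assert_allclose(y[l_max:], 0.0)
```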