Skip to content

Commit e0c9c64

Browse files
lucianopaztwiecki
authored andcommitted
Swap expand_dims of Periodic and WrappedPeriodic kernel full method
1 parent 5f29b25 commit e0c9c64

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

pymc/gp/cov.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,8 @@ def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable
779779
X, Xs = self._slice(X, Xs)
780780
if Xs is None:
781781
Xs = X
782-
f1 = pt.expand_dims(X, axis=(0,))
783-
f2 = pt.expand_dims(Xs, axis=(1,))
782+
f1 = pt.expand_dims(X, axis=(1,))
783+
f2 = pt.expand_dims(Xs, axis=(0,))
784784
r = np.pi * (f1 - f2) / self.period
785785
r2 = pt.sum(pt.square(pt.sin(r) / self.ls), 2)
786786
return self.full_from_distance(r2, squared=True)
@@ -946,8 +946,8 @@ def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable
946946
X, Xs = self._slice(X, Xs)
947947
if Xs is None:
948948
Xs = X
949-
f1 = pt.expand_dims(X, axis=(0,))
950-
f2 = pt.expand_dims(Xs, axis=(1,))
949+
f1 = pt.expand_dims(X, axis=(1,))
950+
f2 = pt.expand_dims(Xs, axis=(0,))
951951
r = np.pi * (f1 - f2) / self.period
952952
r2 = pt.sum(pt.square(pt.sin(r) / self.cov_func.ls), 2)
953953
return self.cov_func.full_from_distance(r2, squared=True)

tests/gp/test_cov.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,3 +872,29 @@ def test_raises3(self):
872872
with pm.Model() as model:
873873
with pytest.raises(ValueError):
874874
B = pm.gp.cov.Coregion(1)
875+
876+
877+
@pytest.mark.parametrize(
878+
["kernel", "args"],
879+
[
880+
["Constant", (1.0,)],
881+
["WhiteNoise", (1.0,)],
882+
["ExpQuad", (1, 1.0)],
883+
["RatQuad", (1, 1.0, 1.0)],
884+
["Exponential", (1, 1.0)],
885+
["Matern12", (1, 1.0)],
886+
["Matern32", (1, 1.0)],
887+
["Matern52", (1, 1.0)],
888+
["Periodic", (1, 1.0, 1.0)],
889+
["Circular", (1, 1.0)],
890+
["Linear", (1, 1.0)],
891+
["Cosine", (1, 1.0)],
892+
["Polynomial", (1, 1.0, 1.0, 1.0)],
893+
["WrappedPeriodic", (pm.gp.cov.ExpQuad(1, 1.0), 1.0)],
894+
["Gibbs", (1, lambda x: pt.ones(x.shape))],
895+
],
896+
)
897+
def test_full_shape(kernel, args):
898+
X = np.arange(10)[:, None]
899+
Xs = np.arange(5)[:, None]
900+
assert tuple(getattr(pm.gp.cov, kernel)(*args).full(X, Xs).shape.eval()) == (len(X), len(Xs))

0 commit comments

Comments
 (0)