Skip to content

Commit 3446bd4

Browse files
Add associative scan (#30)
* First attempt at associative scan * vmap binary operator * Fix mathematics * Remove double expm_vp * Amend test * Flexible scan * Split tests * Add diffrax associative scan comparison * associative scan notebook polishing * add associative scan runtime plot * polish notebook * add same key handling for associative_scan * improve docstring * add underscore to function name and remove docstringÒ * fix tests * Cleaner associative_scan handling of b * Move sample and prob to top of file * Fix solve * Generalise expm_vp --------- Co-authored-by: KaelanDt <[email protected]>
1 parent cb85b64 commit 3446bd4

File tree

4 files changed

+405
-128
lines changed

4 files changed

+405
-128
lines changed

examples/associative_scan.ipynb

Lines changed: 251 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_sampler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,27 @@
55

66

77
def test_sample_array_input():
8+
jax.config.update("jax_enable_x64", True)
89
key = jax.random.PRNGKey(0)
910
dim = 2
1011
dt = 0.1
1112
ts = jnp.arange(0, 10_000, dt)
1213

13-
A = jnp.array([[3, 2], [2, 4.0]])
14-
b, x0 = jnp.zeros(dim), jnp.zeros(dim)
14+
# Add some noise to the time points to make the timesteps different
15+
ts += jax.random.uniform(key, (ts.shape[0],)) * dt
16+
ts = ts.sort()
17+
18+
A = jnp.array([[3, 2.5], [2, 4.0]])
19+
b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
20+
x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
1521
D = 2 * jnp.eye(dim)
1622

17-
samples = thermox.sample(key, ts, x0, A, b, D)
23+
samples = thermox.sample(key, ts, x0, A, b, D, associative_scan=False)
1824

1925
samp_cov = jnp.cov(samples.T)
2026
samp_mean = jnp.mean(samples.T, axis=1)
2127
assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1)
2228
assert jnp.allclose(samp_mean, b, atol=1e-1)
29+
30+
samples_as = thermox.sample(key, ts, x0, A, b, D, associative_scan=True)
31+
assert jnp.allclose(samples, samples_as, atol=1e-6)

thermox/prob.py

Lines changed: 47 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,74 +8,7 @@
88
ProcessedDriftMatrix,
99
ProcessedDiffusionMatrix,
1010
)
11-
12-
13-
def log_prob_identity_diffusion(
14-
ts: Array,
15-
xs: Array,
16-
A: Array | ProcessedDriftMatrix,
17-
b: Array,
18-
) -> float:
19-
"""Calculates log probability of samples from the Ornstein-Uhlenbeck process,
20-
defined as:
21-
22-
dx = - A * (x - b) dt + dW
23-
24-
by using exact diagonalization.
25-
26-
Assumes x(t_0) is given deterministically.
27-
28-
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).
29-
30-
Args:
31-
ts: Times at which samples are collected. Includes time for x0.
32-
xs: Initial state of the process.
33-
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
34-
b: Drift displacement vector.
35-
Returns:
36-
Scalar log probability of given xs.
37-
"""
38-
if isinstance(A, Array):
39-
A = preprocess_drift_matrix(A)
40-
41-
def expm_vp(v, dt):
42-
out = A.eigvecs_inv @ v
43-
out = jnp.exp(-A.eigvals * dt) * out
44-
out = A.eigvecs @ out
45-
return out.real
46-
47-
def transition_mean(y, dt):
48-
return b + expm_vp(y - b, dt)
49-
50-
def transition_cov_sqrt_inv_vp(v, dt):
51-
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
52-
diag = jnp.where(diag < 1e-20, 1e-20, diag)
53-
out = A.sym_eigvecs.T @ v
54-
out = out / diag
55-
return out.real
56-
57-
def transition_cov_log_det(dt):
58-
diag = (1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)
59-
diag = jnp.where(diag < 1e-20, 1e-20, diag)
60-
return jnp.sum(jnp.log(diag))
61-
62-
def logpt(yt, y0, dt):
63-
mean = transition_mean(y0, dt)
64-
diff_val = transition_cov_sqrt_inv_vp(yt - mean, dt)
65-
return (
66-
-jnp.dot(diff_val, diff_val) / 2
67-
- transition_cov_log_det(dt) / 2
68-
- jnp.log(2 * jnp.pi) * (yt.shape[0] / 2)
69-
)
70-
71-
log_prob_val = fori_loop(
72-
1,
73-
len(ts),
74-
lambda i, val: val + logpt(xs[i], xs[i - 1], ts[i] - ts[i - 1]),
75-
0.0,
76-
)
77-
78-
return log_prob_val.real
11+
from thermox.sampler import expm_vp
7912

8013

8114
def log_prob(
@@ -105,7 +38,7 @@ def log_prob(
10538
ts: Times at which samples are collected. Includes time for x0.
10639
xs: Initial state of the process.
10740
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
108-
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
41+
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
10942
must be the transformed drift matrix, A_y, given by thermox.preprocess,
11043
not thermox.utils.preprocess_drift_matrix.
11144
b: Drift displacement vector.
@@ -122,3 +55,48 @@ def log_prob(
12255

12356
D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv))
12457
return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)
58+
59+
60+
def transition_cov_sqrt_inv_vp(A, v, dt):
61+
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
62+
diag = jnp.where(diag < 1e-20, 1e-20, diag)
63+
out = A.sym_eigvecs.T @ v
64+
out = out / diag
65+
return out.real
66+
67+
68+
def transition_cov_log_det(A, dt):
69+
diag = (1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)
70+
diag = jnp.where(diag < 1e-20, 1e-20, diag)
71+
return jnp.sum(jnp.log(diag))
72+
73+
74+
def log_prob_identity_diffusion(
75+
ts: Array,
76+
xs: Array,
77+
A: Array | ProcessedDriftMatrix,
78+
b: Array,
79+
) -> float:
80+
if isinstance(A, Array):
81+
A = preprocess_drift_matrix(A)
82+
83+
def transition_mean(y, dt):
84+
return b + expm_vp(A, y - b, dt)
85+
86+
def logpt(yt, y0, dt):
87+
mean = transition_mean(y0, dt)
88+
diff_val = transition_cov_sqrt_inv_vp(A, yt - mean, dt)
89+
return (
90+
-jnp.dot(diff_val, diff_val) / 2
91+
- transition_cov_log_det(A, dt) / 2
92+
- jnp.log(2 * jnp.pi) * (yt.shape[0] / 2)
93+
)
94+
95+
log_prob_val = fori_loop(
96+
1,
97+
len(ts),
98+
lambda i, val: val + logpt(xs[i], xs[i - 1], ts[i] - ts[i - 1]),
99+
0.0,
100+
)
101+
102+
return log_prob_val.real

thermox/sampler.py

Lines changed: 95 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from functools import partial
12
import jax
23
import jax.numpy as jnp
3-
from jax.lax import scan
44
from jax import Array
55

66
from thermox.utils import (
@@ -11,108 +11,147 @@
1111
)
1212

1313

14-
def sample_identity_diffusion(
14+
def sample(
1515
key: Array,
1616
ts: Array,
1717
x0: Array,
1818
A: Array | ProcessedDriftMatrix,
1919
b: Array,
20+
D: Array | ProcessedDiffusionMatrix,
21+
associative_scan: bool = True,
2022
) -> Array:
2123
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
2224
23-
dx = - A * (x - b) dt + dW
25+
dx = - A * (x - b) dt + sqrt(D) dW
2426
2527
by using exact diagonalization.
2628
27-
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
29+
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
2830
where T=len(ts).
2931
32+
If associative_scan=True then jax.lax.associative_scan is used which will run in
33+
time O((T/p + log(T)) * d^2) on a GPU/TPU with p cores, still with
34+
O(d^3) preprocessing.
35+
36+
By default, this function does the preprocessing on A and D before the evaluation.
37+
However, the preprocessing can be done externally using thermox.preprocess
38+
the output of which can be used as A and D here, this will skip the preprocessing.
39+
3040
Args:
3141
key: Jax PRNGKey.
3242
ts: Times at which samples are collected. Includes time for x0.
3343
x0: Initial state of the process.
3444
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
45+
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
46+
must be the transformed drift matrix, A_y, given by thermox.preprocess,
47+
not thermox.utils.preprocess_drift_matrix.
3548
b: Drift displacement vector.
49+
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
50+
associative_scan: If True, uses jax.lax.associative_scan.
3651
3752
Returns:
3853
Array-like, desired samples.
3954
shape: (len(ts), ) + x0.shape
4055
"""
56+
A_y, D = handle_matrix_inputs(A, D)
57+
58+
y0 = D.sqrt_inv @ x0
59+
b_y = D.sqrt_inv @ b
60+
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)
61+
return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)
4162

63+
64+
def sample_identity_diffusion(
65+
key: Array,
66+
ts: Array,
67+
x0: Array,
68+
A: Array | ProcessedDriftMatrix,
69+
b: Array,
70+
associative_scan: bool = True,
71+
) -> Array:
72+
if associative_scan:
73+
return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b)
74+
else:
75+
return _sample_identity_diffusion_scan(key, ts, x0, A, b)
76+
77+
78+
def expm_vp(A, v, dt):
79+
out = A.eigvecs_inv @ v
80+
out = jnp.exp(-A.eigvals * dt) * out
81+
out = A.eigvecs @ out
82+
return out.real
83+
84+
85+
def transition_cov_sqrt_vp(A, v, dt):
86+
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
87+
out = diag * v
88+
out = A.sym_eigvecs @ out
89+
return out.real
90+
91+
92+
def _sample_identity_diffusion_scan(
93+
key: Array,
94+
ts: Array,
95+
x0: Array,
96+
A: Array | ProcessedDriftMatrix,
97+
b: Array,
98+
) -> Array:
4299
if isinstance(A, Array):
43100
A = preprocess_drift_matrix(A)
44101

45-
def expm_vp(v, dt):
46-
out = A.eigvecs_inv @ v
47-
out = jnp.exp(-A.eigvals * dt) * out
48-
out = A.eigvecs @ out
49-
return out.real
50-
51102
def transition_mean(x, dt):
52-
return b + expm_vp(x - b, dt)
103+
return b + expm_vp(A, x - b, dt)
53104

54-
def transition_cov_sqrt_vp(v, dt):
55-
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
56-
out = diag * v
57-
out = A.sym_eigvecs @ out
58-
return out.real
105+
def next_x(x, dt, rv):
106+
return transition_mean(x, dt) + transition_cov_sqrt_vp(A, rv, dt)
59107

60-
def next_x(x, dt, tkey):
61-
randv = jax.random.normal(tkey, shape=x.shape)
62-
return transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt)
63-
64-
def scan_body(x_and_key, dt):
65-
x, rk = x_and_key
66-
rk, rk_use = jax.random.split(rk)
67-
x = next_x(x, dt, rk_use)
68-
return (x, rk), x
108+
def scan_body(carry, dt_and_rv):
109+
x = carry
110+
dt, rv = dt_and_rv
111+
new_x = next_x(x, dt, rv)
112+
return new_x, new_x
69113

70114
dts = jnp.diff(ts)
115+
gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
116+
117+
# Stack dts and gauss_samps along a new axis
118+
dt_and_rv = (dts, gauss_samps)
71119

72-
xs = scan(scan_body, (x0, key), dts)[1]
120+
_, xs = jax.lax.scan(scan_body, x0, dt_and_rv)
73121
xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
74122
return xs
75123

76124

77-
def sample(
125+
def _sample_identity_diffusion_associative_scan(
78126
key: Array,
79127
ts: Array,
80128
x0: Array,
81129
A: Array | ProcessedDriftMatrix,
82130
b: Array,
83-
D: Array | ProcessedDiffusionMatrix,
84131
) -> Array:
85-
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
86-
87-
dx = - A * (x - b) dt + sqrt(D) dW
132+
if isinstance(A, Array):
133+
A = preprocess_drift_matrix(A)
88134

89-
by using exact diagonalization.
135+
dts = jnp.diff(ts)
90136

91-
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
92-
where T=len(ts).
137+
# transition_mean(x, dt) = b + expm_vp(A, x - b, dt)
93138

94-
By default, this function does the preprocessing on A and D before the evaluation.
95-
However, the preprocessing can be done externally using thermox.preprocess
96-
the output of which can be used as A and D here, this will skip the preprocessing.
139+
gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
140+
noise_terms = jax.vmap(lambda v, dt: transition_cov_sqrt_vp(A, v, dt))(
141+
gauss_samps, dts
142+
)
97143

98-
Args:
99-
key: Jax PRNGKey.
100-
ts: Times at which samples are collected. Includes time for x0.
101-
x0: Initial state of the process.
102-
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
103-
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
104-
must be the transformed drift matrix, A_y, given by thermox.preprocess,
105-
not thermox.utils.preprocess_drift_matrix.
106-
b: Drift displacement vector.
107-
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
144+
@partial(jax.vmap, in_axes=(0, 0))
145+
def binary_associative_operator(elem_a, elem_b):
146+
t_a, x_a = elem_a
147+
t_b, x_b = elem_b
148+
return t_a + t_b, expm_vp(A, x_a, t_b) + x_b
108149

109-
Returns:
110-
Array-like, desired samples.
111-
shape: (len(ts), ) + x0.shape
112-
"""
113-
A_y, D = handle_matrix_inputs(A, D)
150+
scan_times = jnp.concatenate([ts[:1], dts], dtype=float) # [t0, dt1, dt2, ...]
151+
scan_input_values = jnp.concatenate(
152+
[x0[None] - b, noise_terms], axis=0
153+
) # Shift input by b
154+
scan_elems = (scan_times, scan_input_values)
114155

115-
y0 = D.sqrt_inv @ x0
116-
b_y = D.sqrt_inv @ b
117-
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y)
118-
return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)
156+
scan_output = jax.lax.associative_scan(binary_associative_operator, scan_elems)
157+
return scan_output[1] + b # Shift back by b

0 commit comments

Comments
 (0)