|
| 1 | +from functools import partial |
1 | 2 | import jax
|
2 | 3 | import jax.numpy as jnp
|
3 |
| -from jax.lax import scan |
4 | 4 | from jax import Array
|
5 | 5 |
|
6 | 6 | from thermox.utils import (
|
|
11 | 11 | )
|
12 | 12 |
|
13 | 13 |
|
14 |
def sample(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    D: Array | ProcessedDiffusionMatrix,
    associative_scan: bool = True,
) -> Array:
    """Collects samples from the Ornstein-Uhlenbeck process, defined as:

        dx = - A * (x - b) dt + sqrt(D) dW

    by using exact diagonalization.

    Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
    where T=len(ts).

    If associative_scan=True then jax.lax.associative_scan is used which will run in
    time O((T/p + log(T)) * d^2) on a GPU/TPU with p cores, still with
    O(d^3) preprocessing.

    By default, this function does the preprocessing on A and D before the evaluation.
    However, the preprocessing can be done externally using thermox.preprocess,
    the output of which can be used as A and D here; this will skip the preprocessing.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
            Note: If a thermox.ProcessedDriftMatrix instance is used as input,
            it must be the transformed drift matrix, A_y, given by
            thermox.preprocess, not thermox.utils.preprocess_drift_matrix.
        b: Drift displacement vector.
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
        associative_scan: If True, uses jax.lax.associative_scan.

    Returns:
        Array-like, desired samples.
            shape: (len(ts), ) + x0.shape
    """
    A_y, D = handle_matrix_inputs(A, D)

    # Change of variables y = D.sqrt_inv @ x, so the identity-diffusion sampler
    # can be reused with the transformed drift A_y and displacement b_y.
    y0 = D.sqrt_inv @ x0
    b_y = D.sqrt_inv @ b
    ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)

    # Map every sample back to the original coordinates: x = D.sqrt @ y.
    return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)
41 | 62 |
|
| 63 | + |
def sample_identity_diffusion(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    associative_scan: bool = True,
) -> Array:
    """Collects samples from the Ornstein-Uhlenbeck process with identity diffusion:

        dx = - A * (x - b) dt + dW

    This is a thin dispatcher: the arguments are forwarded unchanged to one of
    two private implementations, selected by ``associative_scan``.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
        b: Drift displacement vector.
        associative_scan: If True, uses the jax.lax.associative_scan based
            implementation; otherwise the sequential jax.lax.scan based one.

    Returns:
        Array of samples, with x0 as the first entry.
            shape: (len(ts), ) + x0.shape
    """
    if associative_scan:
        return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b)
    else:
        return _sample_identity_diffusion_scan(key, ts, x0, A, b)
| 76 | + |
| 77 | + |
def expm_vp(A, v, dt):
    """Applies the drift propagator over a time step ``dt`` to the vector ``v``.

    Computes the real part of

        A.eigvecs @ (exp(-A.eigvals * dt) * (A.eigvecs_inv @ v))

    which equals expm(-M * dt) @ v when
    M = A.eigvecs @ diag(A.eigvals) @ A.eigvecs_inv
    (presumably the drift matrix that was preprocessed into ``A``).

    Args:
        A: Object exposing ``eigvals``, ``eigvecs`` and ``eigvecs_inv``.
        v: Vector to propagate.
        dt: Time step.

    Returns:
        Real part of the propagated vector.
    """
    out = A.eigvecs_inv @ v  # coordinates of v in the eigenbasis
    out = jnp.exp(-A.eigvals * dt) * out  # decay each eigen-component
    out = A.eigvecs @ out  # back to the original basis
    # The eigen-decomposition may be complex; only the real part is returned.
    return out.real
| 83 | + |
| 84 | + |
def transition_cov_sqrt_vp(A, v, dt):
    """Multiplies ``v`` by a square root of the transition covariance over ``dt``.

    Computes the real part of

        A.sym_eigvecs @ (sqrt((1 - exp(-2 * lam * dt)) / (2 * lam)) * v)

    with lam = A.sym_eigvals.

    Args:
        A: Object exposing ``sym_eigvals`` and ``sym_eigvecs``.
        v: Vector to transform (a standard normal draw in the samplers).
        dt: Time step.

    Returns:
        Real part of the transformed vector.
    """
    # -expm1(-x) equals 1 - exp(-x) but avoids catastrophic cancellation when
    # x = 2 * lam * dt is tiny (in float32, 1 - exp(-2e-9) evaluates to 0).
    diag = (-jnp.expm1(-2 * A.sym_eigvals * dt) / (2 * A.sym_eigvals)) ** 0.5
    out = diag * v
    out = A.sym_eigvecs @ out
    return out.real
| 90 | + |
| 91 | + |
def _sample_identity_diffusion_scan(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> Array:
    """Sequentially samples the identity-diffusion process with jax.lax.scan.

    Each step draws the next state from the current one as

        x_next = b + expm_vp(A, x - b, dt) + transition_cov_sqrt_vp(A, rv, dt)

    where ``rv`` is a standard normal draw and ``dt`` is the gap between
    consecutive entries of ``ts``.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix). A raw Array
            is passed through preprocess_drift_matrix first.
        b: Drift displacement vector.

    Returns:
        Array of samples, with x0 as the first entry.
            shape: (len(ts), ) + x0.shape
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    def transition_mean(x, dt):
        return b + expm_vp(A, x - b, dt)

    def next_x(x, dt, rv):
        return transition_mean(x, dt) + transition_cov_sqrt_vp(A, rv, dt)

    def scan_body(carry, dt_and_rv):
        x = carry
        dt, rv = dt_and_rv
        new_x = next_x(x, dt, rv)
        # The new state is both the carry and the emitted sample.
        return new_x, new_x

    dts = jnp.diff(ts)
    # All noise is drawn up front from the single key, one draw per step.
    gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)

    # Pair dts and gauss_samps as a tuple; scan slices both along the
    # leading axis and hands scan_body one (dt, rv) pair per step.
    dt_and_rv = (dts, gauss_samps)

    _, xs = jax.lax.scan(scan_body, x0, dt_and_rv)
    # scan only emits the states after each step, so prepend x0.
    xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
    return xs
|
75 | 123 |
|
76 | 124 |
|
77 |
def _sample_identity_diffusion_associative_scan(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> Array:
    """Samples the identity-diffusion process with jax.lax.associative_scan.

    In coordinates shifted by ``b`` the one-step update is

        z_next = expm_vp(A, z, dt) + noise,    z = x - b

    Each scan element is a pair (elapsed time, offset). Two elements are
    combined by adding their times and propagating the left offset through
    the right element's elapsed time before adding the right offset, which
    lets the whole trajectory be computed as a prefix scan.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix). A raw Array
            is passed through preprocess_drift_matrix first.
        b: Drift displacement vector.

    Returns:
        Array of samples, with x0 as the first entry.
            shape: (len(ts), ) + x0.shape
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    dts = jnp.diff(ts)

    # transition_mean(x, dt) = b + expm_vp(A, x - b, dt)

    # Draw all noise up front and scale each draw by its own step's
    # transition covariance square root.
    gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
    noise_terms = jax.vmap(lambda v, dt: transition_cov_sqrt_vp(A, v, dt))(
        gauss_samps, dts
    )

    # associative_scan calls the operator on batches of elements, hence the vmap.
    @partial(jax.vmap, in_axes=(0, 0))
    def binary_associative_operator(elem_a, elem_b):
        t_a, x_a = elem_a
        t_b, x_b = elem_b
        return t_a + t_b, expm_vp(A, x_a, t_b) + x_b

    # NOTE(review): the leading ts[:1] entry looks like a placeholder for the
    # x0 element; only the right operand's time is used for propagation and
    # the x0 element is expected to stay leftmost - confirm.
    scan_times = jnp.concatenate([ts[:1], dts], dtype=float)  # [t0, dt1, dt2, ...]
    scan_input_values = jnp.concatenate(
        [x0[None] - b, noise_terms], axis=0
    )  # Shift input by b
    scan_elems = (scan_times, scan_input_values)

    scan_output = jax.lax.associative_scan(binary_associative_operator, scan_elems)
    return scan_output[1] + b  # Shift back by b
0 commit comments