1
- import jax
2
1
import jax .numpy as jnp
3
- from thermox .sampler import sample , sample_identity_diffusion
4
2
from jax .lax import fori_loop
5
- from jax import Array
3
+ from jax import Array , random
4
+ from thermox .sampler import sample , sample_identity_diffusion
5
+ from thermox .utils import ProcessedDriftMatrix
6
6
7
7
8
8
def solve (
9
- A ,
9
+ A : Array | ProcessedDriftMatrix ,
10
10
b ,
11
11
num_samples : int = 10000 ,
12
12
dt : float = 1.0 ,
@@ -34,15 +34,15 @@ def solve(
34
34
Approximate solution, x, of the linear system.
35
35
"""
36
36
if key is None :
37
- key = jax . random .PRNGKey (0 )
37
+ key = random .PRNGKey (0 )
38
38
ts = jnp .arange (burnin , burnin + num_samples ) * dt
39
39
x0 = jnp .zeros_like (b )
40
40
samples = sample_identity_diffusion (key , ts , x0 , A , jnp .linalg .solve (A , b ))
41
41
return jnp .mean (samples , axis = 0 )
42
42
43
43
44
44
def inv (
45
- A ,
45
+ A : Array ,
46
46
num_samples : int = 10000 ,
47
47
dt : float = 1.0 ,
48
48
burnin : int = 0 ,
@@ -65,7 +65,7 @@ def inv(
65
65
Approximate inverse of A.
66
66
"""
67
67
if key is None :
68
- key = jax . random .PRNGKey (0 )
68
+ key = random .PRNGKey (0 )
69
69
ts = jnp .arange (burnin , burnin + num_samples ) * dt
70
70
b = jnp .zeros (A .shape [0 ])
71
71
x0 = jnp .zeros_like (b )
@@ -74,7 +74,7 @@ def inv(
74
74
75
75
76
76
def expnegm (
77
- A ,
77
+ A : Array ,
78
78
num_samples : int = 10000 ,
79
79
dt : float = 1.0 ,
80
80
burnin : int = 0 ,
@@ -100,7 +100,7 @@ def expnegm(
100
100
Approximate negative matrix exponential, exp(-A).
101
101
"""
102
102
if key is None :
103
- key = jax . random .PRNGKey (0 )
103
+ key = random .PRNGKey (0 )
104
104
105
105
A_shifted = (A + alpha * jnp .eye (A .shape [0 ])) / dt
106
106
B = A_shifted + A_shifted .T
@@ -113,7 +113,7 @@ def expnegm(
113
113
114
114
115
115
def expm (
116
- A ,
116
+ A : Array ,
117
117
num_samples : int = 10000 ,
118
118
dt : float = 1.0 ,
119
119
burnin : int = 0 ,
0 commit comments