Skip to content

Commit cb85b64

Browse files
Merge pull request #27 from normal-computing/linalg-typehints
Add type hints to linalg matrix inputs
2 parents 7b88fb1 + 6b9ed8f commit cb85b64

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "thermox"
33
version = "0.0.1"
4-
description = "OU Processes and Linear Algebra with JAX"
4+
description = "Exact OU processes with JAX"
55
readme = "README.md"
66
requires-python =">=3.9"
77
license = {text = "Apache-2.0"}

thermox/linalg.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import jax
21
import jax.numpy as jnp
3-
from thermox.sampler import sample, sample_identity_diffusion
42
from jax.lax import fori_loop
5-
from jax import Array
3+
from jax import Array, random
4+
from thermox.sampler import sample, sample_identity_diffusion
5+
from thermox.utils import ProcessedDriftMatrix
66

77

88
def solve(
9-
A,
9+
A: Array | ProcessedDriftMatrix,
1010
b,
1111
num_samples: int = 10000,
1212
dt: float = 1.0,
@@ -34,15 +34,15 @@ def solve(
3434
Approximate solution, x, of the linear system.
3535
"""
3636
if key is None:
37-
key = jax.random.PRNGKey(0)
37+
key = random.PRNGKey(0)
3838
ts = jnp.arange(burnin, burnin + num_samples) * dt
3939
x0 = jnp.zeros_like(b)
4040
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
4141
return jnp.mean(samples, axis=0)
4242

4343

4444
def inv(
45-
A,
45+
A: Array,
4646
num_samples: int = 10000,
4747
dt: float = 1.0,
4848
burnin: int = 0,
@@ -65,7 +65,7 @@ def inv(
6565
Approximate inverse of A.
6666
"""
6767
if key is None:
68-
key = jax.random.PRNGKey(0)
68+
key = random.PRNGKey(0)
6969
ts = jnp.arange(burnin, burnin + num_samples) * dt
7070
b = jnp.zeros(A.shape[0])
7171
x0 = jnp.zeros_like(b)
@@ -74,7 +74,7 @@ def inv(
7474

7575

7676
def expnegm(
77-
A,
77+
A: Array,
7878
num_samples: int = 10000,
7979
dt: float = 1.0,
8080
burnin: int = 0,
@@ -100,7 +100,7 @@ def expnegm(
100100
Approximate negative matrix exponential, exp(-A).
101101
"""
102102
if key is None:
103-
key = jax.random.PRNGKey(0)
103+
key = random.PRNGKey(0)
104104

105105
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
106106
B = A_shifted + A_shifted.T
@@ -113,7 +113,7 @@ def expnegm(
113113

114114

115115
def expm(
116-
A,
116+
A: Array,
117117
num_samples: int = 10000,
118118
dt: float = 1.0,
119119
burnin: int = 0,

0 commit comments

Comments
 (0)