Skip to content

Commit ba06aab

Browse files
wip for jax
1 parent 2bd0627 commit ba06aab

File tree

5 files changed

+41
-32
lines changed

5 files changed

+41
-32
lines changed

lifelines/fitters/__init__.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from numpy.linalg import inv, pinv
1313
import numpy as np
1414

15-
from autograd import hessian, value_and_grad, elementwise_grad as egrad, grad
15+
from jax import hessian, value_and_grad, grad, vmap, vjp
1616
from autograd.differential_operators import make_jvp_reversemode
17-
from autograd.misc import flatten
18-
import autograd.numpy as anp
17+
from jax.flatten_util import ravel_pytree as flatten
18+
import jax.numpy as jnp
1919

2020
from scipy.optimize import minimize, root_scalar
2121
from scipy.integrate import trapz
@@ -39,6 +39,16 @@
3939
]
4040

4141

42+
def egrad(g):
43+
# assumes the gradient is taken w.r.t. argnum=1, i.e. the second positional argument (`times`)
44+
def wrapped(params, times):
45+
y, g_vjp = vjp(lambda times: g(params, times), times)
46+
(x_bar,) = g_vjp(jnp.ones_like(y))
47+
return x_bar
48+
49+
return wrapped
50+
51+
4252
class BaseFitter:
4353

4454
weights: np.ndarray
@@ -384,30 +394,30 @@ def _buffer_bounds(self, bounds: list[tuple[t.Optional[float], t.Optional[float]
384394
yield (lb + self._MIN_PARAMETER_VALUE, ub - self._MIN_PARAMETER_VALUE)
385395

386396
def _cumulative_hazard(self, params, times):
387-
return -anp.log(self._survival_function(params, times))
397+
return -jnp.log(self._survival_function(params, times))
388398

389399
def _hazard(self, *args, **kwargs):
390400
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
391-
return egrad(self._cumulative_hazard, argnum=1)(*args, **kwargs)
401+
return egrad(self._cumulative_hazard)(*args, **kwargs)
392402

393403
def _density(self, *args, **kwargs):
394404
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
395-
return egrad(self._cumulative_density, argnum=1)(*args, **kwargs)
405+
return egrad(self._cumulative_density)(*args, **kwargs)
396406

397407
def _survival_function(self, params, times):
398-
return anp.exp(-self._cumulative_hazard(params, times))
408+
return jnp.exp(-self._cumulative_hazard(params, times))
399409

400410
def _cumulative_density(self, params, times):
401411
return 1 - self._survival_function(params, times)
402412

403413
def _log_hazard(self, params, times):
404414
hz = self._hazard(params, times)
405-
hz = anp.clip(hz, 1e-50, np.inf)
406-
return anp.log(hz)
415+
hz = jnp.clip(hz, 1e-50, np.inf)
416+
return jnp.log(hz)
407417

408418
def _log_1m_sf(self, params, times):
409419
# equal to log(cdf), but often easier to express with sf.
410-
return anp.log1p(-self._survival_function(params, times))
420+
return jnp.log1p(-self._survival_function(params, times))
411421

412422
def _negative_log_likelihood_left_censoring(self, params, Ts, E, entry, weights) -> float:
413423
T = Ts[1]
@@ -449,8 +459,8 @@ def _negative_log_likelihood_interval_censoring(self, params, Ts, E, entry, weig
449459
ll
450460
+ (
451461
censored_weights
452-
* anp.log(
453-
anp.clip(
462+
* jnp.log(
463+
jnp.clip(
454464
self._survival_function(params, censored_starts) - self._survival_function(params, censored_stops),
455465
1e-25,
456466
1 - 1e-25,
@@ -1391,23 +1401,23 @@ def _check_values_pre_fitting(self, df, T, E, weights, entries):
13911401
utils.check_entry_times(T, entries)
13921402

13931403
def _cumulative_hazard(self, params, T, Xs):
1394-
return -anp.log(self._survival_function(params, T, Xs))
1404+
return -jnp.log(self._survival_function(params, T, Xs))
13951405

13961406
def _hazard(self, params, T, Xs):
1397-
return egrad(self._cumulative_hazard, argnum=1)(params, T, Xs) # pylint: disable=unexpected-keyword-arg
1407+
return egrad(self._cumulative_hazard)(params, T, Xs) # pylint: disable=unexpected-keyword-arg
13981408

13991409
def _log_hazard(self, params, T, Xs):
14001410
# can be overridden to improve convergence, see example in WeibullAFTFitter
14011411
hz = self._hazard(params, T, Xs)
1402-
hz = anp.clip(hz, 1e-20, np.inf)
1403-
return anp.log(hz)
1412+
hz = jnp.clip(hz, 1e-20, np.inf)
1413+
return jnp.log(hz)
14041414

14051415
def _log_1m_sf(self, params, T, Xs):
14061416
# equal to log(cdf), but often easier to express with sf.
1407-
return anp.log1p(-self._survival_function(params, T, Xs))
1417+
return jnp.log1p(-self._survival_function(params, T, Xs))
14081418

14091419
def _survival_function(self, params, T, Xs):
1410-
return anp.clip(anp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)
1420+
return jnp.clip(jnp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)
14111421

14121422
def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) -> float:
14131423

@@ -1422,7 +1432,7 @@ def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs)
14221432
ll = ll + (W * E * log_hz).sum()
14231433
ll = ll + -(W * cum_hz).sum()
14241434
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
1425-
ll = ll / anp.sum(W)
1435+
ll = ll / jnp.sum(W)
14261436
return ll
14271437

14281438
def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float:
@@ -1438,16 +1448,16 @@ def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float
14381448
ll = 0
14391449
ll = (W * E * (log_hz - cum_haz - log_1m_sf)).sum() + (W * log_1m_sf).sum()
14401450
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
1441-
ll = ll / anp.sum(W)
1451+
ll = ll / jnp.sum(W)
14421452
return ll
14431453

14441454
def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> float:
14451455

14461456
start, stop = Ts
14471457
non_zero_entries = entries > 0
14481458
observed_deaths = self._log_hazard(params, stop[E], Xs.filter(E)) - self._cumulative_hazard(params, stop[E], Xs.filter(E))
1449-
censored_interval_deaths = anp.log(
1450-
anp.clip(
1459+
censored_interval_deaths = jnp.log(
1460+
jnp.clip(
14511461
self._survival_function(params, start[~E], Xs.filter(~E))
14521462
- self._survival_function(params, stop[~E], Xs.filter(~E)),
14531463
1e-25,
@@ -1460,7 +1470,7 @@ def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> f
14601470
ll = ll + (W[E] * observed_deaths).sum()
14611471
ll = ll + (W[~E] * censored_interval_deaths).sum()
14621472
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
1463-
ll = ll / anp.sum(W)
1473+
ll = ll / jnp.sum(W)
14641474
return ll
14651475

14661476
@utils.CensoringType.left_censoring
@@ -1885,7 +1895,7 @@ def _add_penalty(self, params: dict, neg_ll: float):
18851895
params_array = params_array[~self._cols_to_not_penalize]
18861896
if (isinstance(self.penalizer, np.ndarray) or self.penalizer > 0) and self.l1_ratio > 0:
18871897
penalty = (
1888-
self.l1_ratio * (self.penalizer * anp.abs(params_array)).sum()
1898+
self.l1_ratio * (self.penalizer * jnp.abs(params_array)).sum()
18891899
+ 0.5 * (1.0 - self.l1_ratio) * (self.penalizer * (params_array) ** 2).sum()
18901900
)
18911901

lifelines/fitters/exponential_fitter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
import numpy as np
3-
from autograd import numpy as anp
3+
from jax import numpy as jnp
44
from lifelines.fitters import KnownModelParametricUnivariateFitter
55

66

@@ -77,4 +77,4 @@ def _cumulative_hazard(self, params, times):
7777

7878
def _log_hazard(self, params, times):
7979
lambda_ = params[0]
80-
return -anp.log(lambda_)
80+
return -jnp.log(lambda_)

lifelines/fitters/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def check_assumptions(
3030
plot_n_bootstraps: int = 15,
3131
columns: Optional[List[str]] = None,
3232
raise_on_fail: bool = False,
33-
) -> None:
33+
) -> list:
3434
"""
3535
Use this function to test the proportional hazards assumption. See usage example at
3636
https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html

lifelines/tests/test_estimation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,8 +1149,8 @@ def test_against_reliability_software(self):
11491149

11501150
class TestExponentialFitter:
11511151
def test_fit_computes_correct_lambda_(self):
1152-
T = np.array([10, 10, 10, 10], dtype=float)
1153-
E = np.array([1, 1, 1, 0], dtype=float)
1152+
T = np.array([10, 20, 10, 10, 5, 10], dtype=float)
1153+
E = np.array([1, 1, 1, 0, 0, 1], dtype=float)
11541154
enf = ExponentialFitter()
11551155
enf.fit(T, E)
11561156
assert abs(enf.lambda_ - (T.sum() / E.sum())) < 1e-4

reqs/base-requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
numpy>=1.14.0,<2.0
1+
numpy>=1.14.0
22
scipy>=1.2.0
33
pandas>=1.2.0
44
matplotlib>=3.0
5-
autograd>=1.5
6-
autograd-gamma>=0.3
75
formulaic>=0.2.2
6+
jax[cpu]>=0.4.0

0 commit comments

Comments
 (0)