12
12
from numpy .linalg import inv , pinv
13
13
import numpy as np
14
14
15
- from autograd import hessian , value_and_grad , elementwise_grad as egrad , grad
15
+ from jax import hessian , value_and_grad , grad , vmap , vjp
16
16
from autograd .differential_operators import make_jvp_reversemode
17
- from autograd . misc import flatten
18
- import autograd .numpy as anp
17
+ from jax . flatten_util import ravel_pytree as flatten
18
+ import jax .numpy as jnp
19
19
20
20
from scipy .optimize import minimize , root_scalar
21
21
from scipy .integrate import trapz
39
39
]
40
40
41
41
42
+ def egrad (g ):
43
+ # assumes grad w.r.t argnum=1
44
+ def wrapped (params , times ):
45
+ y , g_vjp = vjp (lambda times : g (params , times ), times )
46
+ (x_bar ,) = g_vjp (jnp .ones_like (y ))
47
+ return x_bar
48
+
49
+ return wrapped
50
+
51
+
42
52
class BaseFitter :
43
53
44
54
weights : np .ndarray
@@ -384,30 +394,30 @@ def _buffer_bounds(self, bounds: list[tuple[t.Optional[float], t.Optional[float]
384
394
yield (lb + self ._MIN_PARAMETER_VALUE , ub - self ._MIN_PARAMETER_VALUE )
385
395
386
396
def _cumulative_hazard (self , params , times ):
387
- return - anp .log (self ._survival_function (params , times ))
397
+ return - jnp .log (self ._survival_function (params , times ))
388
398
389
399
def _hazard (self , * args , ** kwargs ):
390
400
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
391
- return egrad (self ._cumulative_hazard , argnum = 1 )(* args , ** kwargs )
401
+ return egrad (self ._cumulative_hazard )(* args , ** kwargs )
392
402
393
403
def _density (self , * args , ** kwargs ):
394
404
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
395
- return egrad (self ._cumulative_density , argnum = 1 )(* args , ** kwargs )
405
+ return egrad (self ._cumulative_density )(* args , ** kwargs )
396
406
397
407
def _survival_function (self , params , times ):
398
- return anp .exp (- self ._cumulative_hazard (params , times ))
408
+ return jnp .exp (- self ._cumulative_hazard (params , times ))
399
409
400
410
def _cumulative_density (self , params , times ):
401
411
return 1 - self ._survival_function (params , times )
402
412
403
413
def _log_hazard (self , params , times ):
404
414
hz = self ._hazard (params , times )
405
- hz = anp .clip (hz , 1e-50 , np .inf )
406
- return anp .log (hz )
415
+ hz = jnp .clip (hz , 1e-50 , np .inf )
416
+ return jnp .log (hz )
407
417
408
418
def _log_1m_sf (self , params , times ):
409
419
# equal to log(cdf), but often easier to express with sf.
410
- return anp .log1p (- self ._survival_function (params , times ))
420
+ return jnp .log1p (- self ._survival_function (params , times ))
411
421
412
422
def _negative_log_likelihood_left_censoring (self , params , Ts , E , entry , weights ) -> float :
413
423
T = Ts [1 ]
@@ -449,8 +459,8 @@ def _negative_log_likelihood_interval_censoring(self, params, Ts, E, entry, weig
449
459
ll
450
460
+ (
451
461
censored_weights
452
- * anp .log (
453
- anp .clip (
462
+ * jnp .log (
463
+ jnp .clip (
454
464
self ._survival_function (params , censored_starts ) - self ._survival_function (params , censored_stops ),
455
465
1e-25 ,
456
466
1 - 1e-25 ,
@@ -1391,23 +1401,23 @@ def _check_values_pre_fitting(self, df, T, E, weights, entries):
1391
1401
utils .check_entry_times (T , entries )
1392
1402
1393
1403
def _cumulative_hazard (self , params , T , Xs ):
1394
- return - anp .log (self ._survival_function (params , T , Xs ))
1404
+ return - jnp .log (self ._survival_function (params , T , Xs ))
1395
1405
1396
1406
def _hazard (self , params , T , Xs ):
1397
- return egrad (self ._cumulative_hazard , argnum = 1 )(params , T , Xs ) # pylint: disable=unexpected-keyword-arg
1407
+ return egrad (self ._cumulative_hazard )(params , T , Xs ) # pylint: disable=unexpected-keyword-arg
1398
1408
1399
1409
def _log_hazard (self , params , T , Xs ):
1400
1410
# can be overwritten to improve convergence, see example in WeibullAFTFitter
1401
1411
hz = self ._hazard (params , T , Xs )
1402
- hz = anp .clip (hz , 1e-20 , np .inf )
1403
- return anp .log (hz )
1412
+ hz = jnp .clip (hz , 1e-20 , np .inf )
1413
+ return jnp .log (hz )
1404
1414
1405
1415
def _log_1m_sf (self , params , T , Xs ):
1406
1416
# equal to log(cdf), but often easier to express with sf.
1407
- return anp .log1p (- self ._survival_function (params , T , Xs ))
1417
+ return jnp .log1p (- self ._survival_function (params , T , Xs ))
1408
1418
1409
1419
def _survival_function (self , params , T , Xs ):
1410
- return anp .clip (anp .exp (- self ._cumulative_hazard (params , T , Xs )), 1e-12 , 1 - 1e-12 )
1420
+ return jnp .clip (jnp .exp (- self ._cumulative_hazard (params , T , Xs )), 1e-12 , 1 - 1e-12 )
1411
1421
1412
1422
def _log_likelihood_right_censoring (self , params , Ts : tuple , E , W , entries , Xs ) -> float :
1413
1423
@@ -1422,7 +1432,7 @@ def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs)
1422
1432
ll = ll + (W * E * log_hz ).sum ()
1423
1433
ll = ll + - (W * cum_hz ).sum ()
1424
1434
ll = ll + (W [non_zero_entries ] * delayed_entries ).sum ()
1425
- ll = ll / anp .sum (W )
1435
+ ll = ll / jnp .sum (W )
1426
1436
return ll
1427
1437
1428
1438
def _log_likelihood_left_censoring (self , params , Ts , E , W , entries , Xs ) -> float :
@@ -1438,16 +1448,16 @@ def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float
1438
1448
ll = 0
1439
1449
ll = (W * E * (log_hz - cum_haz - log_1m_sf )).sum () + (W * log_1m_sf ).sum ()
1440
1450
ll = ll + (W [non_zero_entries ] * delayed_entries ).sum ()
1441
- ll = ll / anp .sum (W )
1451
+ ll = ll / jnp .sum (W )
1442
1452
return ll
1443
1453
1444
1454
def _log_likelihood_interval_censoring (self , params , Ts , E , W , entries , Xs ) -> float :
1445
1455
1446
1456
start , stop = Ts
1447
1457
non_zero_entries = entries > 0
1448
1458
observed_deaths = self ._log_hazard (params , stop [E ], Xs .filter (E )) - self ._cumulative_hazard (params , stop [E ], Xs .filter (E ))
1449
- censored_interval_deaths = anp .log (
1450
- anp .clip (
1459
+ censored_interval_deaths = jnp .log (
1460
+ jnp .clip (
1451
1461
self ._survival_function (params , start [~ E ], Xs .filter (~ E ))
1452
1462
- self ._survival_function (params , stop [~ E ], Xs .filter (~ E )),
1453
1463
1e-25 ,
@@ -1460,7 +1470,7 @@ def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> f
1460
1470
ll = ll + (W [E ] * observed_deaths ).sum ()
1461
1471
ll = ll + (W [~ E ] * censored_interval_deaths ).sum ()
1462
1472
ll = ll + (W [non_zero_entries ] * delayed_entries ).sum ()
1463
- ll = ll / anp .sum (W )
1473
+ ll = ll / jnp .sum (W )
1464
1474
return ll
1465
1475
1466
1476
@utils .CensoringType .left_censoring
@@ -1885,7 +1895,7 @@ def _add_penalty(self, params: dict, neg_ll: float):
1885
1895
params_array = params_array [~ self ._cols_to_not_penalize ]
1886
1896
if (isinstance (self .penalizer , np .ndarray ) or self .penalizer > 0 ) and self .l1_ratio > 0 :
1887
1897
penalty = (
1888
- self .l1_ratio * (self .penalizer * anp .abs (params_array )).sum ()
1898
+ self .l1_ratio * (self .penalizer * jnp .abs (params_array )).sum ()
1889
1899
+ 0.5 * (1.0 - self .l1_ratio ) * (self .penalizer * (params_array ) ** 2 ).sum ()
1890
1900
)
1891
1901
0 commit comments