Closed
Description
We can get O(log(T)) time for `sample` and `log_prob`, assuming O(T) cores (where T is the number of time steps), by using `jax.lax.associative_scan`.
Relevant TensorFlow code here
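Below is a minimal sketch of the idea. It is not the code from this issue or from the linked TensorFlow implementation. It assumes the distribution is a discrete hidden Markov model with K states, and all names (`log_pi`, `log_A`, `log_B`, `hmm_log_prob_parallel`, `sample_chain_parallel`) are hypothetical. For `log_prob`, the forward recursion is a product of per-step log-space matrices, and that product is associative. For `sample`, each step pre-draws the next state for every possible previous state, which turns the step into a map on {0, ..., K-1}. Function composition is associative, so `jax.lax.associative_scan` can combine the steps with O(log T) depth. A sequential `jax.lax.scan` forward pass is included as a reference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_matmul(log_x, log_y):
    """Log-space matrix product, batched over leading dims: (..., K, K)."""
    # out[..., i, j] = logsumexp_k(log_x[..., i, k] + log_y[..., k, j])
    return logsumexp(log_x[..., :, :, None] + log_y[..., None, :, :], axis=-2)


def hmm_log_prob_parallel(log_pi, log_A, log_B):
    """log p(x_0..x_{T-1}) of a (hypothetical) HMM with O(log T) scan depth.

    log_pi: (K,) initial log-probs.
    log_A:  (K, K), log_A[i, j] = log p(z_t = j | z_{t-1} = i).
    log_B:  (T, K), log_B[t, k] = log p(x_t | z_t = k).
    """
    K = log_pi.shape[0]
    # First element: every row holds the initial forward message log_pi + log_B[0].
    first = jnp.broadcast_to(log_pi + log_B[0], (1, K, K))
    # Later elements: M_t[i, j] = log_A[i, j] + log_B[t, j].
    rest = log_A[None] + log_B[1:, None, :]
    prefixes = jax.lax.associative_scan(log_matmul, jnp.concatenate([first, rest]))
    # Row 0 of each prefix product is the forward message at that step.
    log_alpha = prefixes[:, 0, :]
    return logsumexp(log_alpha[-1])


def hmm_log_prob_sequential(log_pi, log_A, log_B):
    """Reference forward algorithm with lax.scan (O(T) depth)."""
    def step(log_alpha, log_b):
        return logsumexp(log_alpha[:, None] + log_A, axis=0) + log_b, None

    log_alpha_last, _ = jax.lax.scan(step, log_pi + log_B[0], log_B[1:])
    return logsumexp(log_alpha_last)


def compose(a, b):
    """Apply map a, then map b (batched): out[..., i] = b[..., a[..., i]]."""
    return jnp.take_along_axis(b, a, axis=-1)


def sample_chain_parallel(key, log_pi, log_A, T):
    """Sample latent states z_0..z_{T-1} (T >= 2) with O(log T) scan depth."""
    K = log_pi.shape[0]
    k0, k1 = jax.random.split(key)
    z0 = jax.random.categorical(k0, log_pi)
    # maps[t, i] ~ p(z_{t+1} | z_t = i): one independent draw per possible
    # previous state, so each row of `maps` is a map {0..K-1} -> {0..K-1}.
    maps = jax.random.categorical(k1, log_A, axis=-1, shape=(T - 1, K))
    # prefixes[t] is the composed map that sends z_0 to z_{t+1}.
    prefixes = jax.lax.associative_scan(compose, maps)
    return jnp.concatenate([z0[None], prefixes[:, z0]])


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
K, T = 3, 50
log_pi = jax.nn.log_softmax(jax.random.normal(k1, (K,)))
log_A = jax.nn.log_softmax(jax.random.normal(k2, (K, K)), axis=-1)
log_B = jax.random.normal(k3, (T, K))  # arbitrary emission log-likelihoods
print(hmm_log_prob_parallel(log_pi, log_A, log_B),
      hmm_log_prob_sequential(log_pi, log_A, log_B))
print(sample_chain_parallel(key, log_pi, log_A, T))
```

The trade-off is extra work in exchange for lower depth. Each log-space combine costs O(K^3), compared with O(K^2) for a sequential forward step. The sampler also draws K candidate states per step instead of one. Both choices only pay off when enough parallel hardware is available.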