-
Notifications
You must be signed in to change notification settings - Fork 267
Closed
Description
Somewhat new to numpyro, though more familiar with Jax, so apologies if this is a known issue.
Modelling the boilerplate off of the baseball and time-series forcasting examples, working on a network inference problem (see here for an older jax version with discussion)
Setup looks like:
@jit
def jax_squareform(edgelist, n=n_nodes):
"""edgelist to adj. matrix"""
empty = np.zeros((n,n))
half = index_add(empty, index[np.triu_indices(n,1)], edgelist)
full = half+half.T
return full
def spread_jax(p,u_init,T):
"""
p: transmission probability matrix
u_init: initial infection node states
T: num. iterations to observe at
"""
def scan_fn(u, t):
u_add = lax.tanh(p@u)
u_p = 1-(1-u)*(1-u_add)
return u_p, u_add
u_end, u_adds = lax.scan(
scan_fn, u_init, np.arange(T)
)
return u_end, u_adds
def diff_kg(infections):
n_cascades, n_nodes = infections.shape
n_edges = n_nodes*(n_nodes-1)//2 # complete graph
# beta hyperpriors
u = ny.sample("u", dist.Uniform(np.zeros(n_edges),
np.ones(n_edges)))
v = ny.sample("v", dist.Gamma(np.ones(n_edges),
20*np.ones(n_edges)))
## Bayesian Inference and Decision Theory, Dr. Laskey (GMU)
Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
s_ij = jax_squareform(Λ) # adjacency matrix to recover via inference
with ny.plate("n_cascades", n_cascades):
# infer source node
ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
x0 = ny.sample("x0", dist.Categorical(ϕ))
# simulate ode and realize
# infectious = spread_jax(s_ij, x0, 0, 5)
infectious, hist = spread_jax(s_ij, x0, 5)
numpyro.sample("obs", dist.Bernoulli(probs=infectious),
obs=infections)
kernel = ny.infer.NUTS(diff_kg)
mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
mcmc.run(PRNGKey(0), infections)
mcmc.print_summary()
samples = mcmc.get_samples()
Where infections
is an array with columns as nodes (0=susceptible, 1=infected) and rows as unique observations, simulated from a "ground-truth" network and different source nodes.
Running based on documentation examples results in the following error that I'm having quite a hard time parsing (sorry for the wall of text):
KeyError Traceback (most recent call last)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
526 try:
--> 527 factory = self._registry[type(constraint)]
528 except KeyError:
KeyError: <class 'numpyro.distributions.constraints._IntegerInterval'>
During handling of the above exception, another exception occurred:
NotImplementedError Traceback (most recent call last)
<ipython-input-33-5d6906d300b6> in <module>
1 kernel = ny.infer.NUTS(diff_kg)
2 mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
----> 3 mcmc.run(PRNGKey(0), cascades)
4 mcmc.print_summary()
5 samples = mcmc.get_samples()
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
1194 collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
1195 if self.num_chains == 1:
-> 1196 states_flat, last_state = self._single_chain_mcmc(rng_key, init_state, init_params,
1197 args, kwargs, collect_fields)
1198 states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields)
1067 def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
1068 if init_state is None:
-> 1069 init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
1070 model_args=args, model_kwargs=kwargs)
1071 if self.postprocess_fn is None:
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
506 # Find valid initial params
507 if self._model and not init_params:
--> 508 init_params, is_valid = find_valid_initial_params(rng_key, self._model,
509 init_strategy=self._init_strategy,
510 param_as_improper=True,
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, param_as_improper, model_args, model_kwargs)
370 # Handle possible vectorization
371 if rng_key.ndim == 1:
--> 372 init_params, is_valid = _find_valid_params(rng_key)
373 else:
374 init_params, is_valid = lax.map(_find_valid_params, rng_key)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key_)
359
360 def _find_valid_params(rng_key_):
--> 361 _, _, prototype_params, is_valid = init_state = body_fn((0, rng_key_, None, None))
362 # Early return if valid params found.
363 if not_jax_tracer(is_valid):
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
329 # Use `block` to not record sample primitives in `init_loc_fn`.
330 seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
--> 331 model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
332 constrained_values, inv_transforms = {}, {}
333 for k, v in model_trace.items():
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
147 :return: `OrderedDict` containing the execution trace.
148 """
--> 149 self(*args, **kwargs)
150 return self.trace
151
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
<ipython-input-26-070713a497d6> in diff_kg(infections)
40 # infer source node
41 ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
---> 42 x0 = ny.sample("x0", dist.Categorical(ϕ))
43
44 # simulate ode and realize
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape)
103
104 # ...and use apply_stack to send it to the Messengers
--> 105 msg = apply_stack(initial_msg)
106 return msg['value']
107
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)
20 pointer = 0
21 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 22 handler.process_message(msg)
23 # When a Messenger sets the "stop" field of a message,
24 # it prevents any Messengers above it on the stack from being applied.
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in process_message(self, msg)
430 msg['value'] = self.param_map[msg['name']]
431 else:
--> 432 base_value = self.substitute_fn(msg) if self.substitute_fn \
433 else self.base_param_map.get(msg['name'], None)
434 if base_value is not None:
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _init_to_uniform(site, radius, skip_param)
226 fn = site['fn']
227 value = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
--> 228 base_transform = biject_to(fn.support)
229 unconstrained_value = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
230 sample_shape=np.shape(base_transform.inv(value)))
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
527 factory = self._registry[type(constraint)]
528 except KeyError:
--> 529 raise NotImplementedError
530
531 return factory(constraint)
NotImplementedError:
Metadata
Metadata
Assignees
Labels
No labels