Skip to content

Obscure NotImplementedError for Categorical #545

@rtbs-dev

Description

@rtbs-dev

Somewhat new to numpyro, though more familiar with Jax, so apologies if this is a known issue.

Modelling the boilerplate off of the baseball and time-series forcasting examples, working on a network inference problem (see here for an older jax version with discussion)

Setup looks like:

@jit 
def jax_squareform(edgelist, n=n_nodes):
    """edgelist to adj. matrix"""
    empty = np.zeros((n,n))
    half = index_add(empty, index[np.triu_indices(n,1)], edgelist)
    full = half+half.T
    return full


def spread_jax(p,u_init,T):
    """
    p: transmission probability matrix
    u_init: initial infection node states
    T: num. iterations to observe at
    """
    def scan_fn(u, t):
        u_add = lax.tanh(p@u)
        u_p = 1-(1-u)*(1-u_add)
        return u_p, u_add
    u_end, u_adds = lax.scan(
        scan_fn, u_init, np.arange(T) 
    )
    return u_end, u_adds


def diff_kg(infections):
    n_cascades, n_nodes  = infections.shape
    n_edges = n_nodes*(n_nodes-1)//2 # complete graph
    
    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges), 
                                         np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges),
                                       20*np.ones(n_edges)))
    ## Bayesian Inference and Decision Theory, Dr. Laskey (GMU)
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference
    
    
    with ny.plate("n_cascades", n_cascades):
        # infer source node
        ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        
        # simulate ode and realize
#         infectious = spread_jax(s_ij, x0, 0, 5)
        infectious, hist = spread_jax(s_ij, x0, 5)
        numpyro.sample("obs", dist.Bernoulli(probs=infectious), 
                       obs=infections)

kernel = ny.infer.NUTS(diff_kg)
mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
mcmc.run(PRNGKey(0), infections)
mcmc.print_summary()
samples = mcmc.get_samples()

Where infections is an array with columns as nodes (0=susceptible, 1=infected) and rows as unique observations, simulated from a "ground-truth" network and different source nodes.

Running based on documentation examples results in the following error that I'm having quite a hard time parsing (sorry for the wall of text):

KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    526         try:
--> 527             factory = self._registry[type(constraint)]
    528         except KeyError:

KeyError: <class 'numpyro.distributions.constraints._IntegerInterval'>

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-33-5d6906d300b6> in <module>
      1 kernel = ny.infer.NUTS(diff_kg)
      2 mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
----> 3 mcmc.run(PRNGKey(0), cascades)
      4 mcmc.print_summary()
      5 samples = mcmc.get_samples()

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
   1194         collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
   1195         if self.num_chains == 1:
-> 1196             states_flat, last_state = self._single_chain_mcmc(rng_key, init_state, init_params,
   1197                                                               args, kwargs, collect_fields)
   1198             states = tree_map(lambda x: x[np.newaxis, ...], states_flat)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields)
   1067     def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
   1068         if init_state is None:
-> 1069             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
   1070                                            model_args=args, model_kwargs=kwargs)
   1071         if self.postprocess_fn is None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    506         # Find valid initial params
    507         if self._model and not init_params:
--> 508             init_params, is_valid = find_valid_initial_params(rng_key, self._model,
    509                                                               init_strategy=self._init_strategy,
    510                                                               param_as_improper=True,

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, param_as_improper, model_args, model_kwargs)
    370     # Handle possible vectorization
    371     if rng_key.ndim == 1:
--> 372         init_params, is_valid = _find_valid_params(rng_key)
    373     else:
    374         init_params, is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key_)
    359 
    360     def _find_valid_params(rng_key_):
--> 361         _, _, prototype_params, is_valid = init_state = body_fn((0, rng_key_, None, None))
    362         # Early return if valid params found.
    363         if not_jax_tracer(is_valid):

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
    329         # Use `block` to not record sample primitives in `init_loc_fn`.
    330         seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
--> 331         model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    332         constrained_values, inv_transforms = {}, {}
    333         for k, v in model_trace.items():

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    147         :return: `OrderedDict` containing the execution trace.
    148         """
--> 149         self(*args, **kwargs)
    150         return self.trace
    151 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

<ipython-input-26-070713a497d6> in diff_kg(infections)
     40         # infer source node
     41         ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
---> 42         x0 = ny.sample("x0", dist.Categorical(ϕ))
     43 
     44         # simulate ode and realize

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape)
    103 
    104     # ...and use apply_stack to send it to the Messengers
--> 105     msg = apply_stack(initial_msg)
    106     return msg['value']
    107 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)
     20     pointer = 0
     21     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 22         handler.process_message(msg)
     23         # When a Messenger sets the "stop" field of a message,
     24         # it prevents any Messengers above it on the stack from being applied.

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in process_message(self, msg)
    430                 msg['value'] = self.param_map[msg['name']]
    431         else:
--> 432             base_value = self.substitute_fn(msg) if self.substitute_fn \
    433                 else self.base_param_map.get(msg['name'], None)
    434             if base_value is not None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _init_to_uniform(site, radius, skip_param)
    226             fn = site['fn']
    227         value = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
--> 228         base_transform = biject_to(fn.support)
    229         unconstrained_value = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
    230                                              sample_shape=np.shape(base_transform.inv(value)))

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    527             factory = self._registry[type(constraint)]
    528         except KeyError:
--> 529             raise NotImplementedError
    530 
    531         return factory(constraint)

NotImplementedError: 

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions