-
Notifications
You must be signed in to change notification settings - Fork 281
Open
Labels
good first issue (Good for newcomers)
Description
Notebook title: GLM-ordinal-regression
Notebook url: https://github.com/pymc-devs/pymc-examples/blob/main/examples/generalized_linear_models/GLM-ordinal-regression.ipynb
Issue description
Unable to run cell 11 in the notebook; it fails with the following JAX import error:
/home/vlad/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[11], line 36
32 return idata, model
35 priors = {"sigma": 1, "beta": [0, 1], "mu": np.linspace(0, K, K - 1)}
---> 36 idata1, model1 = make_model(priors, model_spec=1)
37 idata2, model2 = make_model(priors, model_spec=2)
38 idata3, model3 = make_model(priors, model_spec=3)
Cell In[11], line 30, in make_model(priors, model_spec, constrained_uniform, logit)
28 else:
29 y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=mu, observed=df.explicit_rating)
---> 30 idata = pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True})
31 idata.extend(pm.sample_posterior_predictive(idata))
32 return idata, model
File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:571, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
567 if not isinstance(step, NUTS):
568 raise ValueError(
569 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
570 )
--> 571 return _sample_external_nuts(
572 sampler=nuts_sampler,
573 draws=draws,
574 tune=tune,
575 chains=chains,
576 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
577 random_seed=random_seed,
578 initvals=initvals,
579 model=model,
580 progressbar=progressbar,
581 idata_kwargs=idata_kwargs,
582 nuts_sampler_kwargs=nuts_sampler_kwargs,
583 **kwargs,
584 )
586 if isinstance(step, list):
587 step = CompoundStep(step)
File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:283, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
280 return idata
282 elif sampler == "numpyro":
--> 283 import pymc.sampling.jax as pymc_jax
285 idata = pymc_jax.sample_numpyro_nuts(
286 draws=draws,
287 tune=tune,
(...)
295 **nuts_sampler_kwargs,
296 )
297 return idata
File ~/py310/lib/python3.10/site-packages/pymc/sampling/jax.py:23
20 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
22 import arviz as az
---> 23 import jax
24 import numpy as np
25 import pytensor.tensor as pt
File ~/py310/lib/python3.10/site-packages/jax/__init__.py:160
158 from jax import abstract_arrays as abstract_arrays
159 from jax import custom_derivatives as custom_derivatives
--> 160 from jax import custom_batching as custom_batching
161 from jax import custom_transpose as custom_transpose
162 from jax import api_util as api_util
File ~/py310/lib/python3.10/site-packages/jax/custom_batching.py:15
1 # Copyright 2021 The JAX Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from jax._src.custom_batching import (
16 custom_vmap,
17 sequential_vmap,
18 )
File ~/py310/lib/python3.10/site-packages/jax/_src/custom_batching.py:19
16 import operator
17 from typing import Callable, Optional
---> 19 from jax import lax
20 from jax._src import api
21 from jax._src import core
File ~/py310/lib/python3.10/site-packages/jax/lax/__init__.py:369
363 from jax._src.lax.ann import (
364 approx_max_k as approx_max_k,
365 approx_min_k as approx_min_k,
366 approx_top_k_p as approx_top_k_p
367 )
368 from jax._src.ad_util import stop_gradient_p as stop_gradient_p
--> 369 from jax.lax import linalg as linalg
371 from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
372 from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
File ~/py310/lib/python3.10/site-packages/jax/lax/linalg.py:15
1 # Copyright 2020 The JAX Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from jax._src.lax.linalg import (
16 cholesky,
17 cholesky_p,
18 eig,
19 eig_p,
20 eigh,
21 eigh_p,
22 hessenberg,
23 hessenberg_p,
24 lu,
25 lu_p,
26 lu_pivots_to_permutation,
27 householder_product,
28 householder_product_p,
29 qr,
30 qr_p,
31 svd,
32 svd_p,
33 triangular_solve,
34 triangular_solve_p,
35 tridiagonal,
36 tridiagonal_p,
37 tridiagonal_solve,
38 tridiagonal_solve_p,
39 schur,
40 schur_p
41 )
44 from jax._src.lax.qdwh import (
45 qdwh as qdwh
46 )
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/linalg.py:37
35 from jax._src.interpreters import mlir
36 from jax._src.lax import control_flow
---> 37 from jax._src.lax import eigh as lax_eigh
38 from jax._src.lax import lax as lax_internal
39 from jax._src.lax import svd as lax_svd
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/eigh.py:39
37 from jax._src.numpy import ufuncs
38 from jax import lax
---> 39 from jax._src.lax import qdwh
40 from jax._src.lax import linalg as lax_linalg
41 from jax._src.lax.stack import Stack
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/qdwh.py:31
28 from typing import Optional, Tuple
30 import jax
---> 31 import jax.numpy as jnp
32 from jax import lax
33 from jax._src import core
File ~/py310/lib/python3.10/site-packages/jax/numpy/__init__.py:260
257 # TODO(phawkins): make this import unconditional after increasing the ml_dtypes
258 # minimum version.
259 import jax._src.numpy.lax_numpy
--> 260 if hasattr(jax._src.numpy.lax_numpy, "int4"):
261 from jax._src.numpy.lax_numpy import (
262 int4 as int4,
263 uint4 as uint4,
264 )
267 from jax._src.numpy.index_tricks import (
268 c_ as c_,
269 index_exp as index_exp,
(...)
273 s_ as s_,
274 )
AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)
Note that this issue tracker is about the contents of the notebooks. If
the notebook is instead triggering a bug or error in pymc, please
report it to https://github.com/pymc-devs/pymc/issues instead.
Expected output
If applicable, describe what should happen instead.
Proposed solution
If applicable, explain possible solutions and workarounds.
Metadata
Metadata
Assignees
Labels
good first issue (Good for newcomers)