Add helper to compute log_likelihood and stop computing it by default #6374

Merged
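
This PR stops computing pointwise log likelihoods by default when traces are converted to InferenceData and instead adds a standalone helper, compute_log_likelihood, to be called explicitly after sampling. A minimal sketch of the new workflow (the toy model and the y_obs name are illustrative, not part of this PR):

import arviz as az
import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y_obs", mu=mu, sigma=1.0, observed=np.random.randn(100))
    idata = pm.sample()  # no log_likelihood group is computed here anymore

with model:
    pm.compute_log_likelihood(idata)  # opt in: adds a log_likelihood group in place

az.loo(idata)  # statistics that need the pointwise log likelihood work again
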
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -63,6 +63,7 @@ jobs:
pymc/tests/sampling/test_forward.py
pymc/tests/sampling/test_population.py
pymc/tests/stats/test_convergence.py
pymc/tests/stats/test_log_likelihood.py

- |
pymc/tests/tuning/test_scaling.py
1 change: 1 addition & 0 deletions docs/source/api/misc.rst
@@ -6,6 +6,7 @@ Other utils
.. autosummary::
:toctree: generated/

compute_log_likelihood
find_constrained_prior
DictToArrayBijection

256 changes: 189 additions & 67 deletions docs/source/learn/core_notebooks/model_comparison.ipynb

Large diffs are not rendered by default.

125 changes: 15 additions & 110 deletions pymc/backends/arviz.py
@@ -21,7 +21,6 @@
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from pytensor.graph.basic import Constant
from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1

import pymc

@@ -153,7 +152,7 @@ def __init__(
trace=None,
prior=None,
posterior_predictive=None,
log_likelihood=True,
log_likelihood=False,
predictions=None,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
@@ -246,68 +245,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
trace_posterior = self.trace[self.ntune :]
return trace_posterior, trace_warmup

def log_likelihood_vals_point(self, point, var, log_like_fun):
"""Compute log likelihood for each observed point."""
# TODO: This is a cheap hack; we should filter-out the correct
# variables some other way
point = {i.name: point[i.name] for i in log_like_fun.f.maker.inputs if i.name in point}
log_like_val = np.atleast_1d(log_like_fun(point))

if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
try:
obs_data = extract_obs_data(self.model.rvs_to_values[var])
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {var}")

mask = obs_data.mask
if np.ndim(mask) > np.ndim(log_like_val):
mask = np.any(mask, axis=-1)
log_like_val = np.where(mask, np.nan, log_like_val)
return log_like_val

def _extract_log_likelihood(self, trace):
"""Compute log likelihood of each observation."""
if self.trace is None:
return None
if self.model is None:
return None

# TODO: We no longer need one function per observed variable
if self.log_likelihood is True:
cached = [
(
var,
self.model.compile_fn(
self.model.logp(var, sum=False)[0],
inputs=self.model.value_vars,
on_unused_input="ignore",
),
)
for var in self.model.observed_RVs
]
else:
cached = [
(
var,
self.model.compile_fn(
self.model.logp(var, sum=False)[0],
inputs=self.model.value_vars,
on_unused_input="ignore",
),
)
for var in self.model.observed_RVs
if var.name in self.log_likelihood
]
log_likelihood_dict = _DefaultTrace(len(trace.chains))
for var, log_like_fun in cached:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
return log_likelihood_dict.trace_dict

@requires("trace")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
@@ -382,49 +319,6 @@ def sample_stats_to_xarray(self):
),
)

@requires("trace")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood and log_p data from PyMC trace."""
if self.predictions or not self.log_likelihood:
return None
data_warmup = {}
data = {}
warn_msg = (
"Could not compute log_likelihood, it will be omitted. "
"Check your model object or set log_likelihood=False"
)
if self.posterior_trace:
try:
data = self._extract_log_likelihood(self.posterior_trace)
except TypeError:
warnings.warn(warn_msg)
if self.warmup_trace:
try:
data_warmup = self._extract_log_likelihood(self.warmup_trace)
except TypeError:
warnings.warn(warn_msg)
return (
dict_to_dataset(
data,
library=pymc,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
),
dict_to_dataset(
data_warmup,
library=pymc,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
),
)

return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=self.dims, default_dims=self.sample_dims
)

@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
@@ -509,7 +403,6 @@ def to_inference_data(self):
id_dict = {
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"predictions": self.predictions_to_xarray(),
**self.priors_to_xarray(),
@@ -519,15 +412,27 @@
id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
else:
id_dict["constant_data"] = self.constant_data_to_xarray()
return InferenceData(save_warmup=self.save_warmup, **id_dict)
idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
if self.log_likelihood:
from pymc.stats.log_likelihood import compute_log_likelihood

idata = compute_log_likelihood(
idata,
var_names=None if self.log_likelihood is True else self.log_likelihood,
extend_inferencedata=True,
model=self.model,
sample_dims=self.sample_dims,
progressbar=False,
)
return idata


def to_inference_data(
trace: Optional["MultiTrace"] = None,
*,
prior: Optional[Mapping[str, Any]] = None,
posterior_predictive: Optional[Mapping[str, Any]] = None,
log_likelihood: Union[bool, Iterable[str]] = True,
log_likelihood: Union[bool, Iterable[str]] = False,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
sample_dims: Optional[List] = None,
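
Since the converter now defaults to log_likelihood=False, the old behavior becomes an explicit opt-in. One way to opt in at sampling time is through idata_kwargs, which pm.sample forwards to to_inference_data; a sketch, with the y_obs name being illustrative:

with model:
    # Restore the old behavior for all observed variables
    idata = pm.sample(idata_kwargs={"log_likelihood": True})

    # Or compute it only for a subset of observed variables
    idata = pm.sample(idata_kwargs={"log_likelihood": ["y_obs"]})
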
4 changes: 2 additions & 2 deletions pymc/sampling/jax.py
@@ -412,7 +412,7 @@ def sample_blackjax_nuts(
else:
idata_kwargs = idata_kwargs.copy()

if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
if idata_kwargs.pop("log_likelihood", False):
tic5 = datetime.now()
print("Computing Log Likelihood...", file=sys.stdout)
log_likelihood = _get_log_likelihood(
@@ -634,7 +634,7 @@ def sample_numpyro_nuts(
else:
idata_kwargs = idata_kwargs.copy()

if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
if idata_kwargs.pop("log_likelihood", False):
tic5 = datetime.now()
print("Computing Log Likelihood...", file=sys.stdout)
log_likelihood = _get_log_likelihood(
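
The JAX-based samplers get the same flipped default. As with pm.sample, passing idata_kwargs opts back in; a sketch assuming the numpyro backend is installed:

from pymc.sampling.jax import sample_numpyro_nuts

with model:
    idata = sample_numpyro_nuts(idata_kwargs={"log_likelihood": True})
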
3 changes: 2 additions & 1 deletion pymc/stats/__init__.py
@@ -27,5 +27,6 @@
if not attr.startswith("__"):
setattr(sys.modules[__name__], attr, obj)

from pymc.stats.log_likelihood import compute_log_likelihood

__all__ = tuple(az.stats.__all__)
__all__ = ("compute_log_likelihood",) + tuple(az.stats.__all__)
130 changes: 130 additions & 0 deletions pymc/stats/log_likelihood.py
@@ -0,0 +1,130 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence

import numpy as np

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar

import pymc

from pymc.backends.arviz import _DefaultTrace
from pymc.model import Model, modelcontext
from pymc.util import dataset_to_point_list

__all__ = ("compute_log_likelihood",)


def compute_log_likelihood(
idata: InferenceData,
*,
var_names: Optional[Sequence[str]] = None,
extend_inferencedata: bool = True,
model: Optional[Model] = None,
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
):
"""Compute elemwise log_likelihood of model given InferenceData with posterior group

Parameters
----------
idata : InferenceData
InferenceData with posterior group
var_names : sequence of str, optional
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
extend_inferencedata : bool, default True
Whether to extend the original InferenceData or return a new one
model : Model, optional
sample_dims : sequence of str, default ("chain", "draw")
progressbar : bool, default True

Returns
-------
idata : InferenceData
InferenceData with log_likelihood group

"""

posterior = idata["posterior"]

model = modelcontext(model)

if var_names is None:
observed_vars = model.observed_RVs
var_names = tuple(rv.name for rv in observed_vars)
else:
observed_vars = [model.named_vars[name] for name in var_names]
if not set(observed_vars).issubset(model.observed_RVs):
raise ValueError(f"var_names must refer to observed_RVs in the model. Got: {var_names}")

# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
# pylint: disable=used-before-assignment
try:
original_rvs_to_values = model.rvs_to_values
original_rvs_to_transforms = model.rvs_to_transforms

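# Swap in fresh, untransformed value variables (clones) for the unobserved RVs so the
# compiled function accepts posterior draws directly; observed RVs keep their value variables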
model.rvs_to_values = {
rv: rv.clone() if rv not in model.observed_RVs else value
for rv, value in model.rvs_to_values.items()
}
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

elemwise_loglike_fn = model.compile_fn(
inputs=model.value_vars,
outs=model.logp(vars=observed_vars, sum=False),
on_unused_input="ignore",
)
finally:
model.rvs_to_values = original_rvs_to_values
model.rvs_to_transforms = original_rvs_to_transforms
# pylint: enable=used-before-assignment

# Ignore Deterministics
posterior_values = posterior[[rv.name for rv in model.free_RVs]]
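# Flatten the sample dims into one flat list of point dicts; stacked_dims records their sizes for later reshaping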
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
n_pts = len(posterior_pts)
loglike_dict = _DefaultTrace(n_pts)
indices = range(n_pts)
if progressbar:
indices = progress_bar(indices, total=n_pts, display=progressbar)

for idx in indices:
loglikes_pts = elemwise_loglike_fn(posterior_pts[idx])
for rv_name, rv_loglike in zip(var_names, loglikes_pts):
loglike_dict.insert(rv_name, rv_loglike, idx)

loglike_trace = loglike_dict.trace_dict
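# Restore the stacked sample dims (e.g. chain, draw) as the leading axes of each array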
for key, array in loglike_trace.items():
loglike_trace[key] = array.reshape(
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
)

loglike_dataset = dict_to_dataset(
loglike_trace,
library=pymc,
dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
coords={
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
},
default_dims=list(sample_dims),
skip_event_dims=True,
)

if extend_inferencedata:
idata.add_groups(dict(log_likelihood=loglike_dataset))
return idata
else:
return loglike_dataset
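
Both return modes can be useful; a short sketch of the difference (model and idata assumed to be in scope):

# Default: mutate idata by adding a log_likelihood group, then return it
idata = pm.compute_log_likelihood(idata, model=model)

# Standalone: leave idata untouched and return the xarray Dataset instead
loglike = pm.compute_log_likelihood(idata, model=model, extend_inferencedata=False)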