From dfa75d01e71e000dcbf448893823e416505501ae Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 07:39:24 +0200 Subject: [PATCH 01/13] new simple pretty console printing for Model and RandomVariable --- pymc3/distributions/distribution.py | 14 ++- pymc3/model.py | 5 + pymc3/printing.py | 163 ++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 pymc3/printing.py diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 2dfa0d01c0..f47568de7b 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextvars +import functools import inspect import multiprocessing import sys @@ -41,6 +42,7 @@ resize_from_dims, resize_from_observed, ) +from pymc3.printing import str_repr from pymc3.util import UNSET, get_repr_for_variable from pymc3.vartypes import string_types @@ -222,7 +224,17 @@ def __new__( # Assigning the testval earlier causes trouble because the RV may not be created with the final shape already. rv_out.tag.test_value = initval - return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform) + rv_out = model.register_rv( + rv_out, name, observed, total_size, dims=dims, transform=transform + ) + + # add in pretty-printing support + rv_out.str_repr = types.MethodType(str_repr, rv_out) + rv_out._repr_latex_ = types.MethodType( + functools.partial(str_repr, formatting="latex"), rv_out + ) + + return rv_out @classmethod def dist( diff --git a/pymc3/model.py b/pymc3/model.py index 81c6f4f437..962417dd6c 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -15,6 +15,7 @@ import collections import itertools import threading +import types import warnings from sys import modules @@ -668,6 +669,10 @@ def __init__( self.deterministics = treelist() self.potentials = treelist() + from pymc3.printing import str_repr + + self.str_repr = types.MethodType(str_repr, self) + @property def model(self): return self diff --git a/pymc3/printing.py b/pymc3/printing.py new file mode 100644 index 0000000000..3c9bf7d0ec --- /dev/null +++ b/pymc3/printing.py @@ -0,0 +1,163 @@ +# Copyright 2021 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from functools import singledispatch +from typing import Union + +from aesara.graph.basic import walk +from aesara.tensor.basic import TensorVariable, Variable +from aesara.tensor.elemwise import DimShuffle +from aesara.tensor.random.basic import RandomVariable +from aesara.tensor.var import TensorConstant + +from pymc3.model import Model + + +@singledispatch +def str_repr(rv: RandomVariable, formatting: str = "plain", include_params: bool = True) -> str: + """Make a human-readable string representation of a RandomVariable in a model, either + LaTeX or plain, optionally with distribution parameter values included.""" + + if include_params: + # first 3 args are always (rng, size, dtype), rest is relevant for distribution + dist_args = [_str_for_input_var(x, formatting=formatting) for x in rv.owner.inputs[3:]] + + print_name = rv.name if rv.name is not None else "" + if "latex" in formatting: + print_name = r"\text{" + _latex_escape(print_name) + "}" + dist_name = rv.owner.op._print_name[1] + if include_params: + return r"${} \sim {}({})$".format(print_name, dist_name, ",~".join(dist_args)) + else: + return fr"${print_name} \sim {dist_name}$" + else: # plain + dist_name = rv.owner.op._print_name[0] + if include_params: + return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args)) + else: + return fr"{print_name} ~ {dist_name}" + + +@str_repr.register +def _(model: Model, formatting: str = "plain", include_params: bool = True) -> str: + """Make a human-readable string representation of Model, listing all random variables + and their distributions, optionally including parameter values.""" + all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs) + + rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv] + rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr] + if "latex" in formatting: + rv_reprs = [ + rv_repr.replace(r"\sim", r"&\sim &").strip("$") + for rv_repr in rv_reprs + if rv_repr is not None + ] + return r"""$$ + \begin{{array}}{{rcl}} + {} + \end{{array}} + $$""".format( + "\\\\".join(rv_reprs) + ) + else: + # align vars on their ~ + names = [s[: s.index("~") - 1] for s in rv_reprs] + distrs = [s[s.index("~") + 2 :] for s in rv_reprs] + maxlen = str(max(len(x) for x in names)) + rv_reprs = [ + ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d) + for n, d in zip(names, distrs) + ] + return "\n".join(rv_reprs) + + +def _str_for_input_var(var: Variable, formatting: str) -> str: + # note we're dispatching both on type(var) and on type(var.owner.op) so cannot + # use the standard functools.singledispatch + if isinstance(var, TensorConstant): + return _str_for_constant(var, formatting) + elif isinstance(var.owner.op, RandomVariable): + return _str_for_input_rv(var, formatting) + elif isinstance(var.owner.op, DimShuffle): + return _str_for_input_var(var.owner.inputs[0], formatting) + else: + return _str_for_expression(var, formatting) + + +def _str_for_input_rv(var: Variable, formatting: str) -> str: + _str = var.name if var.name is not None else "" + if "latex" in formatting: + return r"\text{" + _latex_escape(_str) + "}" + else: + return _str + + +def _str_for_constant(var: TensorConstant, formatting: str) -> str: + if len(var.data.shape) == 0: + return f"{var.data:.3g}" + elif len(var.data.shape) == 1 and var.data.shape[0] == 1: + return f"{var.data[0]:.3g}" + elif "latex" in formatting: + return r"\text{}" + else: + return r"" + + +def _str_for_expression(var: Variable, formatting: str) -> str: + # construct a string like f(a1, ..., aN) listing all random variables a as arguments + def _expand(x): + if x.owner and (not isinstance(x.owner.op, RandomVariable)): + return reversed(x.owner.inputs) + + parents = [ + x + for x in walk(nodes=var.owner.inputs, expand=_expand) + if x.owner and isinstance(x.owner.op, RandomVariable) + ] + names = [x.name for x in parents] + + if "latex" in formatting: + return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")" + else: + return r"f(" + ", ".join(names) + ")" + + +def _latex_escape(text: str) -> str: + # Note that this is *NOT* a proper LaTeX escaper, on purpose. _repr_latex_ is + # primarily used in the context of Jupyter notebooks, which render using MathJax. + # MathJax is a subset of LaTeX proper, which expects only $ to be escaped. If we were + # to also escape e.g. _ (replace with \_), then "\_" will show up in the output, etc. + return text.replace("$", r"\$") + + +def _default_repr_pretty(obj: Union[TensorVariable, Model], p, cycle): + """Handy plug-in method to instruct IPython-like REPLs to use our str_repr below.""" + # we know that our str_repr does not recurse, so we can ignore cycle + try: + p.text(obj.str_repr()) + except AttributeError: + # the default fallback option + IPython.lib.pretty._repr_pprint(obj, p, cycle) + + +try: + import IPython + + IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) + IPython.lib.pretty.for_type(Model, _default_repr_pretty) +except (ModuleNotFoundError, AttributeError): + # no ipython shell + pass From 747c044da94942f21ecacafe347b89e53ca8460a Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 07:47:08 +0200 Subject: [PATCH 02/13] removing obsolete repr/latex code --- pymc3/distributions/bart.py | 13 ----- pymc3/distributions/bound.py | 19 -------- pymc3/distributions/distribution.py | 74 +---------------------------- pymc3/distributions/simulator.py | 14 ------ pymc3/model.py | 41 ---------------- pymc3/util.py | 38 --------------- 6 files changed, 1 insertion(+), 198 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index d2e8429296..6d4256c323 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -282,16 +282,3 @@ class BART(BaseBART): def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None): super().__init__(X, Y, m, alpha, split_prior) - - def _str_repr(self, name=None, dist=None, formatting="plain"): - if dist is None: - dist = self - X = (type(self.X),) - Y = (type(self.Y),) - alpha = self.alpha - m = self.m - - if "latex" in formatting: - return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$" - else: - return f"{name} ~ BART(alpha = {alpha}, m = {m})" diff --git a/pymc3/distributions/bound.py b/pymc3/distributions/bound.py index bbb19d5065..89f02bdcfb 100644 --- a/pymc3/distributions/bound.py +++ b/pymc3/distributions/bound.py @@ -143,25 +143,6 @@ def random(self, point=None, size=None): # ) pass - def _distr_parameters_for_repr(self): - return ["lower", "upper"] - - def _distr_name_for_repr(self): - return "Bound" - - def _str_repr(self, **kwargs): - distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped}) - if "formatting" in kwargs and "latex" in kwargs["formatting"]: - distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :] - else: - distr_repr = distr_repr[distr_repr.index(" ~") + 3 :] - self_repr = super()._str_repr(**kwargs) - - if "formatting" in kwargs and "latex" in kwargs["formatting"]: - return self_repr + " -- " + distr_repr - else: - return self_repr + "-" + distr_repr - class _DiscreteBounded(_Bounded, Discrete): def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs): diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index f47568de7b..3424801d5e 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -13,7 +13,6 @@ # limitations under the License. import contextvars import functools -import inspect import multiprocessing import sys import types @@ -43,7 +42,7 @@ resize_from_observed, ) from pymc3.printing import str_repr -from pymc3.util import UNSET, get_repr_for_variable +from pymc3.util import UNSET from pymc3.vartypes import string_types __all__ = [ @@ -325,77 +324,6 @@ def dist( return rv_out - def _distr_parameters_for_repr(self): - """Return the names of the parameters for this distribution (e.g. "mu" - and "sigma" for Normal). Used in generating string (and LaTeX etc.) - representations of Distribution objects. By default based on inspection - of __init__, but can be overwritten if necessary (e.g. to avoid including - "sd" and "tau"). - """ - return inspect.getfullargspec(self.__init__).args[1:] - - def _distr_name_for_repr(self): - return self.__class__.__name__ - - def _str_repr(self, name=None, dist=None, formatting="plain"): - """ - Generate string representation for this distribution, optionally - including LaTeX markup (formatting='latex'). - - Parameters - ---------- - name : str - name of the distribution - dist : Distribution - the distribution object - formatting : str - one of { "latex", "plain", "latex_with_params", "plain_with_params" } - """ - if dist is None: - dist = self - if name is None: - name = "[unnamed]" - supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"} - if not formatting in supported_formattings: - raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.") - - param_names = self._distr_parameters_for_repr() - param_values = [ - get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names - ] - - if "latex" in formatting: - param_string = ",~".join( - [fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)] - ) - if formatting == "latex_with_params": - return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format( - var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string - ) - return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format( - var_name=name, distr_name=dist._distr_name_for_repr() - ) - else: - # one of the plain formattings - param_string = ", ".join( - [f"{name}={value}" for name, value in zip(param_names, param_values)] - ) - if formatting == "plain_with_params": - return f"{name} ~ {dist._distr_name_for_repr()}({param_string})" - return f"{name} ~ {dist._distr_name_for_repr()}" - - def __str__(self, **kwargs): - try: - return self._str_repr(formatting="plain", **kwargs) - except: - return super().__str__() - - def _repr_latex_(self, *, formatting="latex_with_params", **kwargs): - """Magic method name for IPython to use for LaTeX formatting.""" - return self._str_repr(formatting=formatting, **kwargs) - - __latex__ = _repr_latex_ - class NoDistribution(Distribution): def __init__( diff --git a/pymc3/distributions/simulator.py b/pymc3/distributions/simulator.py index 8b5951b1ad..0b0fba1d30 100644 --- a/pymc3/distributions/simulator.py +++ b/pymc3/distributions/simulator.py @@ -121,20 +121,6 @@ def random(self, point=None, size=None): # else: # return np.array([self.function(*params) for _ in range(size[0])]) - def _str_repr(self, name=None, dist=None, formatting="plain"): - if dist is None: - dist = self - name = name - function = dist.function.__name__ - params = ", ".join([var.name for var in dist.params]) - sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat - distance = getattr(self.distance, "__name__", self.distance.__class__.__name__) - - if "latex" in formatting: - return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$" - else: - return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})" - def identity(x): """Identity function, used as a summary statistics.""" diff --git a/pymc3/model.py b/pymc3/model.py index 962417dd6c..ee682a3454 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -13,7 +13,6 @@ # limitations under the License. import collections -import itertools import threading import types import warnings @@ -1633,46 +1632,6 @@ def point_logps(self, point=None, round_vals=2): name="Log-probability of test_point", ) - def _str_repr(self, formatting="plain", **kwargs): - all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs) - - if "latex" in formatting: - rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv] - rv_reprs = [ - rv_repr.replace(r"\sim", r"&\sim &").strip("$") - for rv_repr in rv_reprs - if rv_repr is not None - ] - return r"""$$ - \begin{{array}}{{rcl}} - {} - \end{{array}} - $$""".format( - "\\\\".join(rv_reprs) - ) - else: - rv_reprs = [rv.__str__() for rv in all_rv] - rv_reprs = [ - rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr - ] - # align vars on their ~ - names = [s[: s.index("~") - 1] for s in rv_reprs] - distrs = [s[s.index("~") + 2 :] for s in rv_reprs] - maxlen = str(max(len(x) for x in names)) - rv_reprs = [ - ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d) - for n, d in zip(names, distrs) - ] - return "\n".join(rv_reprs) - - def __str__(self, **kwargs): - return self._str_repr(formatting="plain", **kwargs) - - def _repr_latex_(self, *, formatting="latex", **kwargs): - return self._str_repr(formatting=formatting, **kwargs) - - __latex__ = _repr_latex_ - # this is really disgusting, but it breaks a self-loop: I can't pass Model # itself as context class init arg. diff --git a/pymc3/util.py b/pymc3/util.py index 13ca788286..9f1dd13274 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -218,44 +218,6 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_repr_for_variable(variable, formatting="plain"): - """Build a human-readable string representation for a variable.""" - if variable is not None and hasattr(variable, "name"): - name = variable.name - elif type(variable) in [float, int, str]: - name = str(variable) - else: - name = None - - if name is None and variable is not None: - if hasattr(variable, "get_parents"): - try: - names = [ - get_repr_for_variable(item, formatting=formatting) - for item in variable.get_parents()[0].inputs - ] - # do not escape_latex these, since it is not idempotent - if "latex" in formatting: - return "f({args})".format( - args=",~".join([n for n in names if isinstance(n, str)]) - ) - else: - return "f({args})".format( - args=", ".join([n for n in names if isinstance(n, str)]) - ) - except IndexError: - pass - value = variable.eval() - if not value.shape or value.shape == (1,): - return value.item() - return "array" - - if "latex" in formatting: - return fr"\text{{{name}}}" - else: - return name - - def get_var_name(var): """Get an appropriate, plain variable name for a variable.""" return getattr(var, "name", str(var)) From f9dcbf61bd57e7594a040167bfee352fe9e3143e Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 07:56:39 +0200 Subject: [PATCH 03/13] add latex support to Model, use PrettyPrinter.break_ --- pymc3/model.py | 2 ++ pymc3/printing.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pymc3/model.py b/pymc3/model.py index ee682a3454..767a431815 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import functools import threading import types import warnings @@ -671,6 +672,7 @@ def __init__( from pymc3.printing import str_repr self.str_repr = types.MethodType(str_repr, self) + self._repr_latex_ = types.MethodType(functools.partial(str_repr, formatting="latex"), self) @property def model(self): diff --git a/pymc3/printing.py b/pymc3/printing.py index 3c9bf7d0ec..aafc1be96c 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -144,16 +144,25 @@ def _latex_escape(text: str) -> str: def _default_repr_pretty(obj: Union[TensorVariable, Model], p, cycle): - """Handy plug-in method to instruct IPython-like REPLs to use our str_repr below.""" + """Handy plug-in method to instruct IPython-like REPLs to use our str_repr above.""" # we know that our str_repr does not recurse, so we can ignore cycle try: - p.text(obj.str_repr()) + output = obj.str_repr() + # Find newlines and replace them with p.break_() + # (see IPython.lib.pretty._repr_pprint) + lines = output.splitlines() + with p.group(): + for idx, output_line in enumerate(lines): + if idx: + p.break_() + p.text(output_line) except AttributeError: - # the default fallback option + # the default fallback option (no str_repr method) IPython.lib.pretty._repr_pprint(obj, p, cycle) try: + # register our custom pretty printer in ipython shells import IPython IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) From 3c7a557eddbfcfa64dd4f4c5a41ff8c3bccdc020 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 08:32:37 +0200 Subject: [PATCH 04/13] more appropriate type hint (TensorVariable) --- pymc3/printing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/printing.py b/pymc3/printing.py index aafc1be96c..c0f12ae8d4 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -27,7 +27,7 @@ @singledispatch -def str_repr(rv: RandomVariable, formatting: str = "plain", include_params: bool = True) -> str: +def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of a RandomVariable in a model, either LaTeX or plain, optionally with distribution parameter values included.""" From a58ca644c78231458434a7f54818f9b0288478da Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 10:34:29 +0200 Subject: [PATCH 05/13] remove obsolete escape_latex --- pymc3/util.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/pymc3/util.py b/pymc3/util.py index 9f1dd13274..e470721766 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import re import warnings from typing import Dict, List, Tuple, Union @@ -25,8 +24,6 @@ from cachetools import LRUCache, cachedmethod -LATEX_ESCAPE_RE = re.compile(r"(%|_|\$|#|&)", re.MULTILINE) - UNSET = object() @@ -118,30 +115,6 @@ def tree_contains(self, item): return dict.__contains__(self, item) -def escape_latex(strng): - r"""Consistently escape LaTeX special characters for _repr_latex_ in IPython - - Implementation taken from the IPython magic `format_latex` - - Examples - -------- - escape_latex('disease_rate') # 'disease\_rate' - - Parameters - ---------- - strng: str - string to escape LaTeX characters - - Returns - ------- - str - A string with LaTeX escaped - """ - if strng is None: - return "None" - return LATEX_ESCAPE_RE.sub(r"\\\1", strng) - - def get_transformed_name(name, transform): r""" Consistent way of transforming names From cd89ad6336426d9de2fe9bfe6caa322352571776 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 11:45:33 +0200 Subject: [PATCH 06/13] add pretty str/latex for Deterministic and Potential --- pymc3/distributions/distribution.py | 6 ++-- pymc3/model.py | 23 +++++++++++-- pymc3/printing.py | 52 +++++++++++++++++++++++++---- 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 3424801d5e..42c33a1afe 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -41,7 +41,7 @@ resize_from_dims, resize_from_observed, ) -from pymc3.printing import str_repr +from pymc3.printing import str_for_dist from pymc3.util import UNSET from pymc3.vartypes import string_types @@ -228,9 +228,9 @@ def __new__( ) # add in pretty-printing support - rv_out.str_repr = types.MethodType(str_repr, rv_out) + rv_out.str_repr = types.MethodType(str_for_dist, rv_out) rv_out._repr_latex_ = types.MethodType( - functools.partial(str_repr, formatting="latex"), rv_out + functools.partial(str_for_dist, formatting="latex"), rv_out ) return rv_out diff --git a/pymc3/model.py b/pymc3/model.py index 767a431815..68d4acafca 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -669,10 +669,12 @@ def __init__( self.deterministics = treelist() self.potentials = treelist() - from pymc3.printing import str_repr + from pymc3.printing import str_for_model - self.str_repr = types.MethodType(str_repr, self) - self._repr_latex_ = types.MethodType(functools.partial(str_repr, formatting="latex"), self) + self.str_repr = types.MethodType(str_for_model, self) + self._repr_latex_ = types.MethodType( + functools.partial(str_for_model, formatting="latex"), self + ) @property def model(self): @@ -1787,6 +1789,13 @@ def Deterministic(name, var, model=None, dims=None, auto=False): model.deterministics.append(var) model.add_random_variable(var, dims) + from pymc3.printing import str_for_deterministic + + var.str_repr = types.MethodType(str_for_deterministic, var) + var._repr_latex_ = types.MethodType( + functools.partial(str_for_deterministic, formatting="latex"), var + ) + return var @@ -1807,4 +1816,12 @@ def Potential(name, var, model=None): var.tag.scaling = None model.potentials.append(var) model.add_random_variable(var) + + from pymc3.printing import str_for_potential + + var.str_repr = types.MethodType(str_for_potential, var) + var._repr_latex_ = types.MethodType( + functools.partial(str_for_potential, formatting="latex"), var + ) + return var diff --git a/pymc3/printing.py b/pymc3/printing.py index c0f12ae8d4..74aa468eda 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -14,7 +14,6 @@ import itertools -from functools import singledispatch from typing import Union from aesara.graph.basic import walk @@ -26,8 +25,7 @@ from pymc3.model import Model -@singledispatch -def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str: +def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of a RandomVariable in a model, either LaTeX or plain, optionally with distribution parameter values included.""" @@ -51,11 +49,10 @@ def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool return fr"{print_name} ~ {dist_name}" -@str_repr.register -def _(model: Model, formatting: str = "plain", include_params: bool = True) -> str: +def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of Model, listing all random variables and their distributions, optionally including parameter values.""" - all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs) + all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials) rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv] rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr] @@ -84,6 +81,44 @@ def _(model: Model, formatting: str = "plain", include_params: bool = True) -> s return "\n".join(rv_reprs) +def str_for_deterministic( + var: TensorVariable, formatting: str = "plain", include_params: bool = True +) -> str: + print_name = var.name if var.name is not None else "" + if "latex" in formatting: + print_name = r"\text{" + _latex_escape(print_name) + "}" + if include_params: + return fr"${print_name} \sim Deterministic[{_str_for_expression(var, formatting=formatting)}]$" + else: + return fr"${print_name} \sim Deterministic$" + else: # plain + if include_params: + return ( + fr"{print_name} ~ Deterministic[{_str_for_expression(var, formatting=formatting)}]" + ) + else: + return fr"{print_name} ~ Deterministic" + + +def str_for_potential( + var: TensorVariable, formatting: str = "plain", include_params: bool = True +) -> str: + print_name = var.name if var.name is not None else "" + if "latex" in formatting: + print_name = r"\text{" + _latex_escape(print_name) + "}" + if include_params: + return ( + fr"${print_name} \sim Potential[{_str_for_expression(var, formatting=formatting)}]$" + ) + else: + return fr"${print_name} \sim Potential$" + else: # plain + if include_params: + return fr"{print_name} ~ Potential[{_str_for_expression(var, formatting=formatting)}]" + else: + return fr"{print_name} ~ Potential" + + def _str_for_input_var(var: Variable, formatting: str) -> str: # note we're dispatching both on type(var) and on type(var.owner.op) so cannot # use the standard functools.singledispatch @@ -93,6 +128,11 @@ def _str_for_input_var(var: Variable, formatting: str) -> str: return _str_for_input_rv(var, formatting) elif isinstance(var.owner.op, DimShuffle): return _str_for_input_var(var.owner.inputs[0], formatting) + elif hasattr(var, "str_repr") and ( + var.str_repr.__func__ is str_for_deterministic or var.str_repr.__func__ is str_for_potential + ): + # display the name for a Deterministic or Potential, rather than the full expression + return var.name else: return _str_for_expression(var, formatting) From dd399e9d6302e7b4d73251a609ce43ff746bc4cd Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Thu, 8 Jul 2021 12:02:12 +0200 Subject: [PATCH 07/13] update escape characters in latex formats --- pymc3/printing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pymc3/printing.py b/pymc3/printing.py index 74aa468eda..1ec56d4348 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -88,9 +88,9 @@ def str_for_deterministic( if "latex" in formatting: print_name = r"\text{" + _latex_escape(print_name) + "}" if include_params: - return fr"${print_name} \sim Deterministic[{_str_for_expression(var, formatting=formatting)}]$" + return fr"${print_name} \sim \operatorname{{Deterministic}}[{_str_for_expression(var, formatting=formatting)}]$" else: - return fr"${print_name} \sim Deterministic$" + return fr"${print_name} \sim \operatorname{{Deterministic}}$" else: # plain if include_params: return ( @@ -107,11 +107,9 @@ def str_for_potential( if "latex" in formatting: print_name = r"\text{" + _latex_escape(print_name) + "}" if include_params: - return ( - fr"${print_name} \sim Potential[{_str_for_expression(var, formatting=formatting)}]$" - ) + return fr"${print_name} \sim \operatorname{{Potential}}[{_str_for_expression(var, formatting=formatting)}]$" else: - return fr"${print_name} \sim Potential$" + return fr"${print_name} \sim \operatorname{{Potential}}$" else: # plain if include_params: return fr"{print_name} ~ Potential[{_str_for_expression(var, formatting=formatting)}]" From 7f6bc9a020781caa840a756efc0ee0b4a5cadcba Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Fri, 9 Jul 2021 10:37:50 +0200 Subject: [PATCH 08/13] small refactor --- pymc3/printing.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pymc3/printing.py b/pymc3/printing.py index 1ec56d4348..26e97a95ac 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -122,15 +122,18 @@ def _str_for_input_var(var: Variable, formatting: str) -> str: # use the standard functools.singledispatch if isinstance(var, TensorConstant): return _str_for_constant(var, formatting) - elif isinstance(var.owner.op, RandomVariable): + elif isinstance(var.owner.op, RandomVariable) or ( + hasattr(var, "str_repr") + and ( + var.str_repr.__func__ is str_for_deterministic + or var.str_repr.__func__ is str_for_potential + ) + ): + # show the names for RandomVariables, Deterministics, and Potentials, rather + # than the full expression return _str_for_input_rv(var, formatting) elif isinstance(var.owner.op, DimShuffle): return _str_for_input_var(var.owner.inputs[0], formatting) - elif hasattr(var, "str_repr") and ( - var.str_repr.__func__ is str_for_deterministic or var.str_repr.__func__ is str_for_potential - ): - # display the name for a Deterministic or Potential, rather than the full expression - return var.name else: return _str_for_expression(var, formatting) From 75045b38a70619787bfeeca4c5c3d05da4013a8c Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Fri, 9 Jul 2021 10:55:54 +0200 Subject: [PATCH 09/13] refactor: unify str_for_potential_or_deterministic --- pymc3/model.py | 22 ++++++++++++++++------ pymc3/printing.py | 46 ++++++++++++++-------------------------------- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/pymc3/model.py b/pymc3/model.py index 68d4acafca..74f75d833a 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1789,11 +1789,16 @@ def Deterministic(name, var, model=None, dims=None, auto=False): model.deterministics.append(var) model.add_random_variable(var, dims) - from pymc3.printing import str_for_deterministic + from pymc3.printing import str_for_potential_or_deterministic - var.str_repr = types.MethodType(str_for_deterministic, var) + var.str_repr = types.MethodType( + functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var + ) var._repr_latex_ = types.MethodType( - functools.partial(str_for_deterministic, formatting="latex"), var + functools.partial( + str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex" + ), + var, ) return var @@ -1817,11 +1822,16 @@ def Potential(name, var, model=None): model.potentials.append(var) model.add_random_variable(var) - from pymc3.printing import str_for_potential + from pymc3.printing import str_for_potential_or_deterministic - var.str_repr = types.MethodType(str_for_potential, var) + var.str_repr = types.MethodType( + functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var + ) var._repr_latex_ = types.MethodType( - functools.partial(str_for_potential, formatting="latex"), var + functools.partial( + str_for_potential_or_deterministic, dist_name="Potential", formatting="latex" + ), + var, ) return var diff --git a/pymc3/printing.py b/pymc3/printing.py index 26e97a95ac..08d8f8add3 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -81,54 +81,36 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool return "\n".join(rv_reprs) -def str_for_deterministic( - var: TensorVariable, formatting: str = "plain", include_params: bool = True +def str_for_potential_or_deterministic( + var: TensorVariable, dist_name: str, formatting: str = "plain", include_params: bool = True ) -> str: print_name = var.name if var.name is not None else "" if "latex" in formatting: print_name = r"\text{" + _latex_escape(print_name) + "}" if include_params: - return fr"${print_name} \sim \operatorname{{Deterministic}}[{_str_for_expression(var, formatting=formatting)}]$" + return fr"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$" else: - return fr"${print_name} \sim \operatorname{{Deterministic}}$" + return fr"${print_name} \sim \operatorname{{{dist_name}}}$" else: # plain if include_params: - return ( - fr"{print_name} ~ Deterministic[{_str_for_expression(var, formatting=formatting)}]" - ) + return fr"{print_name} ~ {dist_name}({_str_for_expression(var, formatting=formatting)})" else: - return fr"{print_name} ~ Deterministic" - - -def str_for_potential( - var: TensorVariable, formatting: str = "plain", include_params: bool = True -) -> str: - print_name = var.name if var.name is not None else "" - if "latex" in formatting: - print_name = r"\text{" + _latex_escape(print_name) + "}" - if include_params: - return fr"${print_name} \sim \operatorname{{Potential}}[{_str_for_expression(var, formatting=formatting)}]$" - else: - return fr"${print_name} \sim \operatorname{{Potential}}$" - else: # plain - if include_params: - return fr"{print_name} ~ Potential[{_str_for_expression(var, formatting=formatting)}]" - else: - return fr"{print_name} ~ Potential" + return fr"{print_name} ~ {dist_name}" def _str_for_input_var(var: Variable, formatting: str) -> str: # note we're dispatching both on type(var) and on type(var.owner.op) so cannot # use the standard functools.singledispatch + + def _is_potential_or_determinstic(var: Variable) -> bool: + return ( + hasattr(var, "str_repr") + and var.str_repr.__func__.func is str_for_potential_or_deterministic + ) + if isinstance(var, TensorConstant): return _str_for_constant(var, formatting) - elif isinstance(var.owner.op, RandomVariable) or ( - hasattr(var, "str_repr") - and ( - var.str_repr.__func__ is str_for_deterministic - or var.str_repr.__func__ is str_for_potential - ) - ): + elif isinstance(var.owner.op, RandomVariable) or _is_potential_or_determinstic(var): # show the names for RandomVariables, Deterministics, and Potentials, rather # than the full expression return _str_for_input_rv(var, formatting) From 8818b71fc23647c64048ef00cf2d583870ddf064 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Fri, 9 Jul 2021 10:59:43 +0200 Subject: [PATCH 10/13] refactor: safer fallback if user code changes str_repr --- pymc3/printing.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pymc3/printing.py b/pymc3/printing.py index 08d8f8add3..777c092032 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -99,14 +99,12 @@ def str_for_potential_or_deterministic( def _str_for_input_var(var: Variable, formatting: str) -> str: - # note we're dispatching both on type(var) and on type(var.owner.op) so cannot - # use the standard functools.singledispatch - def _is_potential_or_determinstic(var: Variable) -> bool: - return ( - hasattr(var, "str_repr") - and var.str_repr.__func__.func is str_for_potential_or_deterministic - ) + try: + return var.str_repr.__func__.func is str_for_potential_or_deterministic + except AttributeError: + # in case other code overrides str_repr, fallback + return False if isinstance(var, TensorConstant): return _str_for_constant(var, formatting) From 86194a71dafaec55a797ddab649d5c60a3d802e3 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Fri, 9 Jul 2021 13:32:06 +0200 Subject: [PATCH 11/13] update tests, add __all__ --- pymc3/printing.py | 13 ++- pymc3/tests/test_distributions.py | 166 ++++++++++++++---------------- 2 files changed, 87 insertions(+), 92 deletions(-) diff --git a/pymc3/printing.py b/pymc3/printing.py index 777c092032..b0889a167f 100644 --- a/pymc3/printing.py +++ b/pymc3/printing.py @@ -24,6 +24,12 @@ from pymc3.model import Model +__all__ = [ + "str_for_dist", + "str_for_model", + "str_for_potential_or_deterministic", +] + def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of a RandomVariable in a model, either @@ -82,8 +88,13 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool def str_for_potential_or_deterministic( - var: TensorVariable, dist_name: str, formatting: str = "plain", include_params: bool = True + var: TensorVariable, + formatting: str = "plain", + include_params: bool = True, + dist_name: str = "Deterministic", ) -> str: + """Make a human-readable string representation of a Deterministic or Potential in a model, either + LaTeX or plain, optionally with distribution parameter values included.""" print_name = var.name if var.name is not None else "" if "latex" in formatting: print_name = r"\text{" + _latex_escape(print_name) + "}" diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2eab4b88d1..27d1998b41 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -104,7 +104,7 @@ logpt_sum, ) from pymc3.math import kronecker -from pymc3.model import Deterministic, Model, Point +from pymc3.model import Deterministic, Model, Point, Potential from pymc3.tests.helpers import select_by_precision from pymc3.vartypes import continuous_types @@ -2786,7 +2786,6 @@ def test_lower_bounded_broadcasted(self): assert upper_interval is None -@pytest.mark.xfail(reason="LaTeX repr and str no longer applicable") class TestStrAndLatexRepr: def setup_class(self): # True parameter values @@ -2802,6 +2801,9 @@ def setup_class(self): # Simulate outcome variable Y = alpha + X.dot(beta) + np.random.randn(size) * sigma with Model() as self.model: + # TODO: some variables commented out here as they're not working properly + # in v4 yet (9-jul-2021), so doesn't make sense to test str/latex for them + # Priors for unknown model parameters alpha = Normal("alpha", mu=0, sigma=10) b = Normal("beta", mu=0, sigma=10, size=(2,), observed=beta) @@ -2811,16 +2813,16 @@ def setup_class(self): Z = MvNormal("Z", mu=np.zeros(2), chol=np.eye(2), size=(2,)) # NegativeBinomial representations to test issue 4186 - nb1 = pm.NegativeBinomial( - "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1) - ) + # nb1 = pm.NegativeBinomial( + # "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1) + # ) nb2 = pm.NegativeBinomial("nb_with_p_n", p=pm.Uniform("nbp"), n=10) # Expected value of outcome mu = Deterministic("mu", floatX(alpha + at.dot(X, b))) # add a bounded variable as well - bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10) + # bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10) # KroneckerNormal n, m = 3, 4 @@ -2828,13 +2830,13 @@ def setup_class(self): kron_normal = KroneckerNormal("kron_normal", mu=np.zeros(n * m), covs=covs, size=n * m) # MatrixNormal - matrix_normal = MatrixNormal( - "mat_normal", - mu=np.random.normal(size=n), - rowcov=np.eye(n), - colchol=np.linalg.cholesky(np.eye(n)), - size=(n, n), - ) + # matrix_normal = MatrixNormal( + # "mat_normal", + # mu=np.random.normal(size=n), + # rowcov=np.eye(n), + # colchol=np.linalg.cholesky(np.eye(n)), + # size=(n, n), + # ) # DirichletMultinomial dm = DirichletMultinomial("dm", n=5, a=[1, 1, 1], size=(2, 3)) @@ -2842,97 +2844,79 @@ def setup_class(self): # Likelihood (sampling distribution) of observations Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y) - self.distributions = [alpha, sigma, mu, b, Z, nb1, nb2, Y_obs, bound_var] + # add a potential as well + pot = Potential("pot", mu ** 2) + + self.distributions = [alpha, sigma, mu, b, Z, nb2, Y_obs, pot] + self.deterministics_or_potentials = [mu, pot] + # tuples of (formatting, include_params + self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)] self.expected = { - "latex": ( - r"$\text{alpha} \sim \text{Normal}$", - r"$\text{sigma} \sim \text{HalfNormal}$", - r"$\text{mu} \sim \text{Deterministic}$", - r"$\text{beta} \sim \text{Normal}$", - r"$\text{Z} \sim \text{MvNormal}$", - r"$\text{nb_with_mu_alpha} \sim \text{NegativeBinomial}$", - r"$\text{nb_with_p_n} \sim \text{NegativeBinomial}$", - r"$\text{Y_obs} \sim \text{Normal}$", - r"$\text{bound_var} \sim \text{Bound}$ -- \text{Normal}$", - r"$\text{kron_normal} \sim \text{KroneckerNormal}$", - r"$\text{mat_normal} \sim \text{MatrixNormal}$", - r"$\text{dm} \sim \text{DirichletMultinomial}$", - ), - "plain": ( - r"alpha ~ Normal", - r"sigma ~ HalfNormal", + ("plain", True): [ + r"alpha ~ N(0, 10)", + r"sigma ~ N**+(0, 1)", + r"mu ~ Deterministic(f(beta, alpha))", + r"beta ~ N(0, 10)", + r"Z ~ N(, f())", + r"nb_with_p_n ~ NB(10, nbp)", + r"Y_obs ~ N(mu, sigma)", + r"pot ~ Potential(f(beta, alpha))", + ], + ("plain", False): [ + r"alpha ~ N", + r"sigma ~ N**+", r"mu ~ Deterministic", - r"beta ~ Normal", - r"Z ~ MvNormal", - r"nb_with_mu_alpha ~ NegativeBinomial", - r"nb_with_p_n ~ NegativeBinomial", - r"Y_obs ~ Normal", - r"bound_var ~ Bound-Normal", - r"kron_normal ~ KroneckerNormal", - r"mat_normal ~ MatrixNormal", - r"dm ~ DirichletMultinomial", - ), - "latex_with_params": ( - r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$", - r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$", - r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$", - r"$\text{nb_with_mu_alpha} \sim \text{NegativeBinomial}(\mathit{mu}=\text{nbmu},~\mathit{alpha}=\text{nbalpha})$", - r"$\text{nb_with_p_n} \sim \text{NegativeBinomial}(\mathit{p}=\text{nbp},~\mathit{n}=10)$", - r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$", - r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$", - r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$", - r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$", - r"$\text{dm} \sim \text{DirichletMultinomial}(\mathit{n}=5,~\mathit{a}=array)$", - ), - "plain_with_params": ( - r"alpha ~ Normal(mu=0.0, sigma=10.0)", - r"sigma ~ HalfNormal(sigma=1.0)", - r"mu ~ Deterministic(alpha, Constant, beta)", - r"beta ~ Normal(mu=0.0, sigma=10.0)", - r"Z ~ MvNormal(mu=array, chol_cov=array)", - r"nb_with_mu_alpha ~ NegativeBinomial(mu=nbmu, alpha=nbalpha)", - r"nb_with_p_n ~ NegativeBinomial(p=nbp, n=10)", - r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))", - r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)", - r"kron_normal ~ KroneckerNormal(mu=array)", - r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)", - r"dm∼DirichletMultinomial(n=5, a=array)", - ), + r"beta ~ N", + r"Z ~ N", + r"nb_with_p_n ~ NB", + r"Y_obs ~ N", + r"pot ~ Potential", + ], + ("latex", True): [ + r"$\text{alpha} \sim \operatorname{N}(0,~10)$", + r"$\text{sigma} \sim \operatorname{N^{+}}(0,~1)$", + r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$", + r"$\text{beta} \sim \operatorname{N}(0,~10)$", + r"$\text{Z} \sim \operatorname{N}(\text{},~f())$", + r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$", + r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$", + r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$", + ], + ("latex", False): [ + r"$\text{alpha} \sim \operatorname{N}$", + r"$\text{sigma} \sim \operatorname{N^{+}}$", + r"$\text{mu} \sim \operatorname{Deterministic}$", + r"$\text{beta} \sim \operatorname{N}$", + r"$\text{Z} \sim \operatorname{N}$", + r"$\text{nb_with_p_n} \sim \operatorname{NB}$", + r"$\text{Y_obs} \sim \operatorname{N}$", + r"$\text{pot} \sim \operatorname{Potential}$", + ], } def test__repr_latex_(self): - for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]): + for distribution, tex in zip(self.distributions, self.expected[("latex", True)]): assert distribution._repr_latex_() == tex model_tex = self.model._repr_latex_() # make sure each variable is in the model - for tex in self.expected["latex"]: + for tex in self.expected[("latex", True)]: for segment in tex.strip("$").split(r"\sim"): assert segment in model_tex - def test___latex__(self): - for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]): - assert distribution._repr_latex_() == distribution.__latex__() - assert self.model._repr_latex_() == self.model.__latex__() - - def test___str__(self): - for distribution, str_repr in zip(self.distributions, self.expected["plain"]): - assert distribution.__str__() == str_repr - - model_str = self.model.__str__() - for str_repr in self.expected["plain"]: - assert str_repr in model_str - - def test_str(self): - for distribution, str_repr in zip(self.distributions, self.expected["plain"]): - assert str(distribution) == str_repr - - model_str = str(self.model) - for str_repr in self.expected["plain"]: - assert str_repr in model_str + def test_str_repr(self): + for str_format in self.formats: + for dist, text in zip(self.distributions, self.expected[str_format]): + assert dist.str_repr(*str_format) == text + + model_text = self.model.str_repr(*str_format) + for text in self.expected[str_format]: + if str_format[0] == "latex": + for segment in text.strip("$").split(r"\sim"): + assert segment in model_text + else: + assert text in model_text def test_discrete_trafo(): From 6352da0441603f0a58fac626b7070ac7b07156e0 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Fri, 9 Jul 2021 13:35:02 +0200 Subject: [PATCH 12/13] import printing in root module --- pymc3/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc3/__init__.py b/pymc3/__init__.py index d1a87f08cd..9706b821a6 100644 --- a/pymc3/__init__.py +++ b/pymc3/__init__.py @@ -68,6 +68,7 @@ def __set_compiler_flags(): from pymc3.model import * from pymc3.model_graph import model_to_graphviz from pymc3.plots import * +from pymc3.printing import * from pymc3.sampling import * from pymc3.smc import * from pymc3.stats import * From 78f9da0b1dc5a7cdf815e81f7da40b41bae85602 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Tue, 13 Jul 2021 16:13:02 +0200 Subject: [PATCH 13/13] use cloudpickle in smc sampling --- pymc3/smc/sample_smc.py | 47 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/pymc3/smc/sample_smc.py b/pymc3/smc/sample_smc.py index e7d8884859..010bd52665 100644 --- a/pymc3/smc/sample_smc.py +++ b/pymc3/smc/sample_smc.py @@ -19,6 +19,7 @@ from collections.abc import Iterable +import cloudpickle import numpy as np from arviz import InferenceData @@ -224,9 +225,12 @@ def sample_smc( pbars = [pbar] + [None] * (chains - 1) pool = mp.Pool(cores) + # "manually" (de)serialize params before/after multiprocessing + params = tuple(cloudpickle.dumps(p) for p in params) results = pool.starmap( - sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)] + _sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)] ) + results = tuple(cloudpickle.loads(r) for r in results) pool.close() pool.join() @@ -237,7 +241,7 @@ def sample_smc( for i in range(chains): pbar.offset = 100 * i pbar.base_comment = f"Chain: {i+1}/{chains}" - results.append(sample_smc_int(*params, random_seed[i], i, pbar)) + results.append(_sample_smc_int(*params, random_seed[i], i, pbar)) ( traces, @@ -316,7 +320,7 @@ def sample_smc( return posterior -def sample_smc_int( +def _sample_smc_int( draws, kernel, n_steps, @@ -332,6 +336,36 @@ def sample_smc_int( progressbar=None, ): """Run one SMC instance.""" + in_out_pickled = type(model) == bytes + if in_out_pickled: + # function was called in multiprocessing context, deserialize first + ( + draws, + kernel, + n_steps, + start, + tune_steps, + p_acc_rate, + threshold, + save_sim_data, + save_log_pseudolikelihood, + model, + ) = map( + cloudpickle.loads, + ( + draws, + kernel, + n_steps, + start, + tune_steps, + p_acc_rate, + threshold, + save_sim_data, + save_log_pseudolikelihood, + model, + ), + ) + smc = SMC( draws=draws, kernel=kernel, @@ -375,7 +409,7 @@ def sample_smc_int( accept_ratios.append(smc.acc_rate) nsteps.append(smc.n_steps) - return ( + results = ( smc.posterior_to_trace(), smc.sim_data, smc.log_marginal_likelihood, @@ -384,3 +418,8 @@ def sample_smc_int( accept_ratios, nsteps, ) + + if in_out_pickled: + results = cloudpickle.dumps(results) + + return results