diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d8160ba..00959ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,3 +31,7 @@ repos: args: [--rcfile=.pylintrc] exclude: (test_*|mcbackend/meta.py|mcbackend/npproto/) files: ^mcbackend/ +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.991 + hooks: + - id: mypy diff --git a/mcbackend/adapters/pymc.py b/mcbackend/adapters/pymc.py index 8002c78..8412a08 100644 --- a/mcbackend/adapters/pymc.py +++ b/mcbackend/adapters/pymc.py @@ -9,16 +9,10 @@ import hagelkorn import numpy - -try: - from pytensor.graph.basic import Constant - from pytensor.tensor.sharedvar import SharedVariable -except ModuleNotFoundError: - from aesara.graph.basic import Constant - from aesara.tensor.sharedvar import SharedVariable - from pymc.backends.base import BaseTrace from pymc.model import Model +from pytensor.graph.basic import Constant +from pytensor.tensor.sharedvar import SharedVariable from mcbackend.meta import Coordinate, DataVariable, Variable diff --git a/mcbackend/backends/clickhouse.py b/mcbackend/backends/clickhouse.py index e2194c6..a36d437 100644 --- a/mcbackend/backends/clickhouse.py +++ b/mcbackend/backends/clickhouse.py @@ -5,7 +5,18 @@ import logging import time from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) import clickhouse_driver import numpy @@ -156,7 +167,7 @@ def __init__( self._client = client # The following attributes belong to the batched insert mechanism. # Inserting in batches is much faster than inserting single rows. - self._str_cols = set() + self._str_cols: Set[str] = set() self._insert_query: str = "" self._insert_queue: List[Dict[str, Any]] = [] self._last_insert = time.time() @@ -176,13 +187,16 @@ def append( self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES" self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name} - # Convert str ndarrays to lists + params_ins: Dict[str, Union[numpy.ndarray, int, float, List[str]]] = { + "_draw_idx": self._draw_idx, + **params, + } + # Convert str-dtyped ndarrays to lists for col in self._str_cols: - params[col] = params[col].tolist() + params_ins[col] = params[col].tolist() # Queue up for insertion - params["_draw_idx"] = self._draw_idx - self._insert_queue.append(params) + self._insert_queue.append(params_ins) self._draw_idx += 1 if ( @@ -242,13 +256,14 @@ def _get_rows( # Without draws return empty arrays of the correct shape/dtype if not draws: - if is_rigid(nshape): - return numpy.empty(shape=[0] + nshape, dtype=dtype) + if is_rigid(nshape) and nshape is not None: + return numpy.empty(shape=[0, *nshape], dtype=dtype) return numpy.array([], dtype=object) # The unpacking must also account for non-rigid shapes # and str-dtyped empty arrays default to fixed length 1 strings. # The [None] list is slower, but more flexible in this regard. + buffer: Union[numpy.ndarray, Sequence] if is_rigid(nshape) and dtype != "str": assert nshape is not None buffer = numpy.empty((draws, *nshape), dtype) @@ -292,7 +307,7 @@ def __init__( self, meta: RunMeta, *, - created_at: datetime = None, + created_at: Optional[datetime] = None, client_fn: Callable[[], clickhouse_driver.Client], ) -> None: self._client_fn = client_fn @@ -331,8 +346,8 @@ class ClickHouseBackend(Backend): def __init__( self, - client: clickhouse_driver.Client = None, - client_fn: Callable[[], clickhouse_driver.Client] = None, + client: Optional[clickhouse_driver.Client] = None, + client_fn: Optional[Callable[[], clickhouse_driver.Client]] = None, ): """Create a ClickHouse backend around a database client. diff --git a/mcbackend/core.py b/mcbackend/core.py index e2fce11..2e31015 100644 --- a/mcbackend/core.py +++ b/mcbackend/core.py @@ -3,16 +3,7 @@ """ import collections import logging -from typing import ( - TYPE_CHECKING, - Dict, - List, - Mapping, - Optional, - Sequence, - Sized, - TypeVar, -) +from typing import Dict, List, Mapping, Optional, Sequence, Sized, TypeVar, Union, cast import numpy @@ -20,14 +11,12 @@ from .npproto.utils import ndarray_to_numpy from .utils import as_array_from_ragged -InferenceData = TypeVar("InferenceData") try: - from arviz import from_dict + from arviz import InferenceData, from_dict - if not TYPE_CHECKING: - from arviz import InferenceData _HAS_ARVIZ = True except ModuleNotFoundError: + InferenceData = TypeVar("InferenceData") # type: ignore _HAS_ARVIZ = False Shape = Sequence[int] @@ -262,20 +251,22 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> warmup_sample_stats[svar.name].append(stats[tune]) sample_stats[svar.name].append(stats[~tune]) + w_pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_posterior) + w_ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_sample_stats) + pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], posterior) + ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], sample_stats) if not equalize_chain_lengths: # Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically - warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()} - warmup_sample_stats = { - k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items() - } - posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()} - sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()} + w_pst = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()} + w_ss = {k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()} + pst = {k: as_array_from_ragged(v) for k, v in posterior.items()} + ss = {k: as_array_from_ragged(v) for k, v in sample_stats.items()} idata = from_dict( - warmup_posterior=warmup_posterior, - warmup_sample_stats=warmup_sample_stats, - posterior=posterior, - sample_stats=sample_stats, + warmup_posterior=w_pst, + warmup_sample_stats=w_ss, + posterior=pst, + sample_stats=ss, coords=self.coords, dims=self.dims, attrs=self.meta.attributes, diff --git a/mcbackend/test_core.py b/mcbackend/test_core.py index a5ee1fc..ed2066b 100644 --- a/mcbackend/test_core.py +++ b/mcbackend/test_core.py @@ -72,10 +72,10 @@ def test_chain_properties(self): def test_chain_length(self): class _TestChain(core.Chain): - def get_draws(self, var_name: str): + def get_draws(self, var_name: str, slc: slice = slice(None)): return numpy.arange(12) - def get_stats(self, stat_name: str): + def get_stats(self, stat_name: str, slc: slice = slice(None)): return numpy.arange(42) rmeta = RunMeta("test", variables=[Variable("v1")]) diff --git a/mcbackend/test_utils.py b/mcbackend/test_utils.py index ea59b27..c696ea7 100644 --- a/mcbackend/test_utils.py +++ b/mcbackend/test_utils.py @@ -1,7 +1,7 @@ import random import time from dataclasses import dataclass -from typing import Sequence +from typing import Optional, Sequence import arviz import hagelkorn @@ -78,9 +78,9 @@ def make_draw(variables: Sequence[Variable]): class BaseBackendTest: """Can be used to test different backends in the same way.""" - cls_backend = None - cls_run = None - cls_chain = None + cls_backend: Optional[type] = None + cls_run: Optional[type] = None + cls_chain: Optional[type] = None def setup_method(self, method): """Override this when the backend has no parameterless constructor.""" @@ -373,10 +373,8 @@ def run_all_benchmarks(self) -> pandas.DataFrame: for attr in dir(BackendBenchmark): meth = getattr(self, attr, None) if callable(meth) and meth.__name__.startswith("measure_"): - try: + if hasattr(self, "setup_method"): self.setup_method(meth) - except TypeError: - pass print(f"Running {meth.__name__}") speed = meth() df.loc[meth.__name__[8:], ["bytes_per_draw", "append_speed", "description"]] = (