Skip to content

Commit 1b2fe3b

Browse files
Add slc option to Chain.get_draws and get_stats
Closes #47
1 parent 183407d commit 1b2fe3b

File tree

5 files changed

+76
-20
lines changed

5 files changed

+76
-20
lines changed

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
pass
2121

2222

23-
__version__ = "0.1.3"
23+
__version__ = "0.2.0"

mcbackend/backends/clickhouse.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,15 @@ def _get_row_at(
166166
result = dict(zip(var_names, data[0][0]))
167167
return result
168168

169-
def _get_rows( # pylint: disable=W0221
169+
def _get_rows(
170170
self,
171171
var_name: str,
172172
nshape: Optional[Sequence[int]],
173173
dtype: str,
174-
*,
175-
burn: int = 0,
174+
slc: slice = slice(None),
176175
) -> numpy.ndarray:
177176
self._commit()
178-
data = self._client.execute(
179-
f"SELECT (`{var_name}`) FROM {self.cid} WHERE _draw_idx>={burn};"
180-
)
177+
data = self._client.execute(f"SELECT (`{var_name}`) FROM {self.cid};")
181178
draws = len(data)
182179

183180
# Safety checks
@@ -201,20 +198,20 @@ def _get_rows( # pylint: disable=W0221
201198
arr[:] = buffer
202199
return arr
203200
# Otherwise (identical shapes) we can collapse into one ndarray
204-
return numpy.asarray(buffer, dtype=dtype)
201+
return numpy.asarray(buffer, dtype=dtype)[slc]
205202

206-
def get_draws(self, var_name: str) -> numpy.ndarray:
203+
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
207204
var = self.variables[var_name]
208205
nshape = var.shape if not var.undefined_ndim else None
209-
return self._get_rows(var_name, nshape, var.dtype)
206+
return self._get_rows(var_name, nshape, var.dtype, slc)
210207

211208
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
212209
return self._get_row_at(idx, var_names)
213210

214-
def get_stats(self, stat_name: str) -> numpy.ndarray:
211+
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
215212
var = self.sample_stats[stat_name]
216213
nshape = var.shape if not var.undefined_ndim else None
217-
return self._get_rows(f"__stat_{stat_name}", nshape, var.dtype)
214+
return self._get_rows(f"__stat_{stat_name}", nshape, var.dtype, slc)
218215

219216
def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
220217
stats = self._get_row_at(idx, [f"__stat_{sname}" for sname in stat_names])

mcbackend/backends/numpy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def append(
8787
def __len__(self) -> int:
8888
return self._draw_idx
8989

90-
def get_draws(self, var_name: str) -> numpy.ndarray:
91-
return self._samples[var_name][: self._draw_idx]
90+
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
91+
return self._samples[var_name][: self._draw_idx][slc]
9292

9393
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
9494
return {vn: numpy.asarray(self._samples[vn][idx]) for vn in var_names}
9595

96-
def get_stats(self, stat_name: str) -> numpy.ndarray:
97-
return self._stats[stat_name][: self._draw_idx]
96+
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
97+
return self._stats[stat_name][: self._draw_idx][slc]
9898

9999
def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
100100
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}

mcbackend/core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,30 @@ def append(
7070
"""
7171
raise NotImplementedError()
7272

73-
def get_draws(self, var_name: str) -> numpy.ndarray:
74-
"""Retrieve all draws of a variable from an MCMC chain."""
73+
def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
74+
"""Retrieve draws of a variable from an MCMC chain.
75+
76+
Parameters
77+
----------
78+
var_name : str
79+
Name of the variable.
80+
slc : slice, optional
81+
A ``slice`` selecting which draws to retrieve; the default selects all of them.
82+
Passing this can be more performant than slicing the returned value.
83+
"""
7584
raise NotImplementedError()
7685

77-
def get_stats(self, stat_name: str) -> numpy.ndarray:
78-
"""Retrieve all values of a sampler statistic."""
86+
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
87+
"""Retrieve values of a sampler statistic.
88+
89+
Parameters
90+
----------
91+
stat_name : str
92+
Name of the stats variable.
93+
slc : slice, optional
94+
A ``slice`` selecting which values to retrieve; the default selects all of them.
95+
Passing this can be more performant than slicing the returned value.
96+
"""
7997
raise NotImplementedError()
8098

8199
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:

mcbackend/test_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,47 @@ def test__append_get_with_changelings(self, with_stats):
178178
numpy.testing.assert_array_equal(act, exp)
179179
pass
180180

181+
@pytest.mark.parametrize(
182+
"slc",
183+
[
184+
None,
185+
slice(None, None, None),
186+
slice(2, None, None),
187+
slice(2, 10, None),
188+
slice(2, 15, 3),
189+
slice(-8, None, None),
190+
slice(-8, -2, 2),
191+
slice(-50, -2, 2),
192+
slice(15, 10),
193+
],
194+
)
195+
def test__get_slicing(self, slc: slice):
196+
rmeta = make_runmeta(
197+
variables=[Variable("A", "uint8")],
198+
sample_stats=[Variable("B", "uint8")],
199+
data=[],
200+
)
201+
run = self.backend.init_run(rmeta)
202+
chain = run.init_chain(0)
203+
204+
# Generate draws and add them to the chain
205+
N = 20
206+
draws = [dict(A=n) for n in range(N)]
207+
stats = [dict(B=n) for n in range(N)]
208+
for d, s in zip(draws, stats):
209+
chain.append(d, s)
210+
assert len(chain) == N
211+
212+
# slc=None in this test means "don't pass it".
213+
# The implementations should default to slc=slice(None, None, None).
214+
expected = numpy.arange(N, dtype="uint8")[slc or slice(None, None, None)]
215+
kwargs = dict(slc=slc) if slc is not None else {}
216+
act_draws = chain.get_draws("A", **kwargs)
217+
act_stats = chain.get_stats("B", **kwargs)
218+
numpy.testing.assert_array_equal(act_draws, expected)
219+
numpy.testing.assert_array_equal(act_stats, expected)
220+
pass
221+
181222
def test__get_chains(self):
182223
rmeta = make_runmeta()
183224
run = self.backend.init_run(rmeta)

0 commit comments

Comments
 (0)