diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py index ee28c53ec..c9a70268c 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py @@ -197,8 +197,8 @@ def sample(self, max_shots: int) -> AnonTaskStats: raise ValueError("predictions.dtype != np.uint8") if len(predictions.shape) != 2: raise ValueError("len(predictions.shape) != 2") - if predictions.shape[0] != num_shots: - raise ValueError("predictions.shape[0] != num_shots") + if predictions.shape[0] != num_shots - num_discards_1: + raise ValueError("predictions.shape[0] != num_shots - num_discards_1") if predictions.shape[1] < actual_obs.shape[1]: raise ValueError("predictions.shape[1] < actual_obs.shape[1]") if predictions.shape[1] > actual_obs.shape[1] + 1: diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py index 413015f0d..2ef3f1a7e 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py @@ -1,9 +1,12 @@ import collections import numpy as np +import stim +from sinter._data import Task +from sinter._decoding._decoding_vacuous import VacuousDecoder from sinter._decoding._stim_then_decode_sampler import \ - classify_discards_and_errors + classify_discards_and_errors, _CompiledStimThenDecodeSampler def test_classify_discards_and_errors(): @@ -190,3 +193,23 @@ def test_classify_discards_and_errors(): num_obs=13, ) == (0, 1) assert counter == collections.Counter(["obs_mistake_mask=_________E___"]) + +def test_detector_post_selection(): + circuit = stim.Circuit(""" + X_ERROR(1) 0 + M 0 + DETECTOR rec[-1] + """) + sampler = _CompiledStimThenDecodeSampler( + decoder=VacuousDecoder(), + task = Task( + circuit=circuit, + detector_error_model=circuit.detector_error_model(), + postselection_mask=np.array([1], dtype=np.uint8), + ), + count_observable_error_combos=False, + count_detection_events=False, + tmp_dir=None + ) + result = sampler.sample(max_shots=1) + assert result.discards == 1 \ No newline at end of file