13 changes: 12 additions & 1 deletion bluesky-tiled-plugins/bluesky_tiled_plugins/consolidators.py
@@ -47,6 +47,12 @@ class DataSource:
    management: Management = Management.writable


@dataclasses.dataclass
class Patch:
    shape: tuple[int, ...]
    offset: tuple[int, ...]


class ConsolidatorBase:
"""Consolidator of StreamDatums

@@ -249,10 +255,15 @@ def consume_stream_datum(self, doc: StreamDatum):
        - Update the list of assets, including their uris, if necessary
        - Update shape and chunks
        """
        old_shape = self.shape  # Adding new rows updates self.shape
        self._num_rows += doc["indices"]["stop"] - doc["indices"]["start"]
        new_seqnums = range(doc["seq_nums"]["start"], doc["seq_nums"]["stop"])
        new_indices = range(doc["indices"]["start"], doc["indices"]["stop"])
        self._seqnums_to_indices_map.update(dict(zip(new_seqnums, new_indices)))
        return Patch(
            offset=(old_shape[0], *[0 for _ in self.shape[1:]]),
            shape=(self.shape[0] - old_shape[0], *self.shape[1:]),
        )

    def get_data_source(self) -> DataSource:
        """Return a DataSource object reflecting the current state of the streamed dataset.
@@ -475,7 +486,7 @@ def consume_stream_datum(self, doc: StreamDatum):
        self.assets.append(new_asset)
        self.data_uris.append(new_datum_uri)

        super().consume_stream_datum(doc)
        return super().consume_stream_datum(doc)


class TIFFConsolidator(MultipartRelatedConsolidator):
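A minimal sketch of what the new `Patch` return value conveys, assuming a consolidator whose dataset grows along the leading axis; the shapes below are invented for illustration and are not taken from this PR:

```python
# Hypothetical walk-through of the Patch computed in consume_stream_datum.
# Assume the consolidated dataset had shape (3, 10, 10) before this
# StreamDatum, and the datum appends two new rows along the first axis.
import dataclasses


@dataclasses.dataclass
class Patch:
    shape: tuple[int, ...]
    offset: tuple[int, ...]


old_shape = (3, 10, 10)  # self.shape before the new rows are counted
new_shape = (5, 10, 10)  # self.shape after self._num_rows is incremented

patch = Patch(
    offset=(old_shape[0], *[0 for _ in new_shape[1:]]),   # (3, 0, 0)
    shape=(new_shape[0] - old_shape[0], *new_shape[1:]),  # (2, 10, 10)
)
assert patch.offset == (3, 0, 0) and patch.shape == (2, 10, 10)
```

In other words, the patch describes only the newly appended slab: where it starts in the full array and how big it is.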
16 changes: 12 additions & 4 deletions bluesky-tiled-plugins/bluesky_tiled_plugins/tiled_writer.py
@@ -41,7 +41,7 @@
from tiled.structures.core import Spec
from tiled.utils import safe_json_dump

from .consolidators import ConsolidatorBase, DataSource, StructureFamily, consolidator_factory
from .consolidators import ConsolidatorBase, DataSource, Patch, StructureFamily, consolidator_factory

# Aggregate the Event table rows and StreamDatums in batches before writing to Tiled
BATCH_SIZE = 10000
@@ -553,16 +553,24 @@ def _write_external_data(self, doc: StreamDatum):

        sres_uid, desc_uid = doc["stream_resource"], doc["descriptor"]
        sres_node, consolidator = self.get_sres_node(sres_uid, desc_uid)
        consolidator.consume_stream_datum(doc)
        self._update_data_source_for_node(sres_node, consolidator.get_data_source())
        patch = consolidator.consume_stream_datum(doc)
        self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch)

    def _update_data_source_for_node(self, node: BaseClient, data_source: DataSource):
    def _update_data_source_for_node(
        self, node: BaseClient, data_source: DataSource, patch: Optional[Patch] = None
    ):
        """Update StreamResource node in Tiled"""
        data_source.id = node.data_sources()[0].id  # ID of the existing DataSource record
        handle_error(
            node.context.http_client.put(
                node.uri.replace("/metadata/", "/data_source/", 1),
                content=safe_json_dump({"data_source": data_source}),
                params={
                    "patch_shape": ",".join(map(str, patch.shape)),
                    "patch_offset": ",".join(map(str, patch.offset)),
                }
                if patch
                else None,
            )
        ).json()

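For reference, a rough sketch of how the writer serializes a patch into query parameters on the PUT to the node's `data_source` endpoint; the shapes and URI here are assumptions for illustration, not values from the PR:

```python
# Sketch only: how patch.shape / patch.offset become comma-separated query params.
patch_shape = (2, 10, 10)  # hypothetical patch describing two appended rows
patch_offset = (3, 0, 0)

params = {
    "patch_shape": ",".join(map(str, patch_shape)),    # "2,10,10"
    "patch_offset": ",".join(map(str, patch_offset)),  # "3,0,0"
}
assert params == {"patch_shape": "2,10,10", "patch_offset": "3,0,0"}
# The resulting request then looks roughly like:
# PUT .../data_source/<path-to-node>?patch_shape=2,10,10&patch_offset=3,0,0
```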
51 changes: 51 additions & 0 deletions bluesky-tiled-plugins/tests/test_tiled_writer.py
@@ -4,6 +4,7 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Optional, Union, cast
from urllib.parse import parse_qs, urlparse

import bluesky.plan_stubs as bps
import bluesky.plans as bp
@@ -381,6 +382,56 @@ def test_with_correct_sample_runs(client, batch_size, external_assets_folder, fn
        assert stream.read() is not None


@pytest.mark.parametrize(
    "batch_size, expected_patch_shapes, expected_patch_offsets",
    [(1, (1, 1, 1), (0, 1, 2)), (2, (2, 1), (0, 2)), (5, (3,), (0,))],
)
def test_data_source_patching(
    client, batch_size, expected_patch_shapes, expected_patch_offsets, external_assets_folder
):
    tw = TiledWriter(client, batch_size=batch_size)

    with record_history() as history:
        for item in render_templated_documents("external_assets.json", external_assets_folder):
            tw(**item)

    def parse_data_source_uri(uri: str):
        """Given a full data_source URL, extract:
        - data_key (e.g. "det-key1")
        - decoded query parameters as tuples of ints

        Returns:
            (data_key, params_dict)
        """

        # data_key is the last component of the path
        data_key = urlparse(uri).path.rstrip("/").split("/")[-1]

        # parse query parameters and convert comma-separated values to tuples of ints
        params = {}
        for k, v in parse_qs(urlparse(uri).query).items():
            params[k] = tuple(map(int, v[0].split(",")))  # parse_qs gives lists

        return data_key, params

    put_uri_params = [
        parse_data_source_uri(str(req.url))
        for req in history.requests
        if req.method == "PUT" and "/data_source" in req.url.path
    ]

    # Check that each data key received the expected number of updates
    assert len(put_uri_params) == 3 * len(expected_patch_shapes)  # 3 data keys in the example
    for data_key in {"det-key1", "det-key2", "det-key3"}:
        assert len([uri for dk, uri in put_uri_params if dk == data_key]) == len(expected_patch_shapes)

        # Check that the patch sizes and offsets (leftmost dimensions) match expectations
        actual_patch_sizes = tuple(params["patch_shape"][0] for dk, params in put_uri_params if dk == data_key)
        actual_patch_offsets = tuple(params["patch_offset"][0] for dk, params in put_uri_params if dk == data_key)
        assert actual_patch_sizes == expected_patch_shapes
        assert actual_patch_offsets == expected_patch_offsets


@pytest.mark.parametrize("error_type", ["shape", "chunks", "dtype"])
@pytest.mark.parametrize("validate", [True, False])
def test_validate_external_data(client, external_assets_folder, error_type, validate):
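A standalone check of the URI-parsing logic used in the new test, run against an invented data_source URI (the host and path below are placeholders, not taken from the test fixtures):

```python
from urllib.parse import parse_qs, urlparse

# Invented example URI in the shape the test expects.
uri = (
    "http://localhost:8000/api/v1/data_source/raw/scan/primary/det-key1"
    "?patch_shape=2,10,10&patch_offset=3,0,0"
)

data_key = urlparse(uri).path.rstrip("/").split("/")[-1]
params = {
    k: tuple(map(int, v[0].split(",")))  # parse_qs returns lists of strings
    for k, v in parse_qs(urlparse(uri).query).items()
}

assert data_key == "det-key1"
assert params == {"patch_shape": (2, 10, 10), "patch_offset": (3, 0, 0)}
```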