diff --git a/src/zarr/array.py b/src/zarr/array.py index 1cc4c8ccff..e366321b15 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -21,7 +21,7 @@ from zarr.abc.store import set_or_delete from zarr.attributes import Attributes from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype -from zarr.chunk_grids import RegularChunkGrid +from zarr.chunk_grids import RegularChunkGrid, _guess_chunks from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters @@ -62,6 +62,9 @@ ) from zarr.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata from zarr.store import StoreLike, StorePath, make_store_path +from zarr.store.core import ( + ensure_no_existing_node, +) from zarr.sync import sync @@ -137,12 +140,13 @@ async def create( compressor: dict[str, JSON] | None = None, # runtime exists_ok: bool = False, + data: npt.ArrayLike | None = None, ) -> AsyncArray: store_path = make_store_path(store) if chunk_shape is None: if chunks is None: - raise ValueError("Either chunk_shape or chunks needs to be provided.") + chunk_shape = chunks = _guess_chunks(shape=shape, typesize=np.dtype(dtype).itemsize) chunk_shape = chunks elif chunks is not None: raise ValueError("Only one of chunk_shape or chunks must be provided.") @@ -164,7 +168,7 @@ async def create( raise ValueError( "compressor cannot be used for arrays with version 3. Use bytes-to-bytes codecs instead." ) - return await cls._create_v3( + result = await cls._create_v3( store_path, shape=shape, dtype=dtype, @@ -187,7 +191,7 @@ async def create( ) if dimension_names is not None: raise ValueError("dimension_names cannot be used for arrays with version 2.") - return await cls._create_v2( + result = await cls._create_v2( store_path, shape=shape, dtype=dtype, @@ -203,6 +207,12 @@ async def create( else: raise ValueError(f"Insupported zarr_format. 
def _guess_chunks(
    shape: ChunkCoords,
    typesize: int,
    *,
    increment_bytes: int = 256 * 1024,
    min_bytes: int = 128 * 1024,
    max_bytes: int = 64 * 1024 * 1024,
) -> ChunkCoords:
    """
    Iteratively guess an appropriate chunk layout for an array, given its shape and
    the size of each element in bytes, and size constraints expressed in bytes. This logic is
    adapted from h5py.

    Parameters
    ----------
    shape: ChunkCoords
        The shape of the array to be chunked. Dimensions of length zero are treated as
        length one, so every dimension of the returned chunk shape is at least 1.
    typesize: int
        The size, in bytes, of each element of the chunk.
    increment_bytes: int = 256 * 1024
        The number of bytes used to increment or decrement the target chunk size in bytes.
    min_bytes: int = 128 * 1024
        The soft lower bound on the final chunk size in bytes.
    max_bytes: int = 64 * 1024 * 1024
        The hard upper bound on the final chunk size in bytes. The only exception is a
        chunk that has been reduced to a single element and still exceeds this bound.

    Returns
    -------
    ChunkCoords
        A tuple of ints with the same length as ``shape``.

    """

    ndims = len(shape)
    # require chunks to have non-zero length for all dimensions
    chunks = np.maximum(np.array(shape, dtype="=f8"), 1)

    # Determine the optimal chunk size in bytes using a PyTables expression.
    # This is kept as a float.
    dset_size = np.prod(chunks) * typesize
    if dset_size > 0:
        target_size = increment_bytes * (2 ** np.log10(dset_size / (1024.0 * 1024)))
    else:
        # Zero-width items make the dataset size 0. log10(0) would emit a RuntimeWarning
        # (fatal under warnings-as-errors) only to produce 2 ** -inf == 0, so use that
        # limit directly; the clamp below then lifts it to min_bytes.
        target_size = 0.0

    if target_size > max_bytes:
        target_size = max_bytes
    elif target_size < min_bytes:
        target_size = min_bytes

    idx = 0
    while True:
        # Repeatedly loop over the axes, dividing them by 2. Stop when:
        # 1a. We're smaller than the target chunk size, OR
        # 1b. We're within 50% of the target chunk size, AND
        # 2. The chunk is smaller than the maximum chunk size

        n_elements = np.prod(chunks)
        chunk_bytes = n_elements * typesize

        if (
            chunk_bytes < target_size or abs(chunk_bytes - target_size) / target_size < 0.5
        ) and chunk_bytes < max_bytes:
            break

        if n_elements == 1:
            break  # Element size larger than max_bytes

        # Halve one axis per iteration, cycling through the axes; ceil keeps each axis >= 1.
        chunks[idx % ndims] = math.ceil(chunks[idx % ndims] / 2.0)
        idx += 1

    return tuple(int(x) for x in chunks)
class ContainsArrayAndGroupError(_BaseZarrError):
    """
    Error for a location that holds both array and group metadata documents
    (.zarray and .zgroup).

    ``_msg`` is a ``str.format`` template: ``{0!r}`` receives the store and
    ``{1!r}`` the path, in the order they are passed to the constructor.
    """

    # Each fragment ends with a space so the implicitly concatenated sentences
    # do not run together in the rendered message.
    _msg = (
        "Array and group metadata documents (.zarray and .zgroup) were both found in store "
        "{0!r} at path {1!r}. "
        "Only one of these files may be present in a given directory / prefix. "
        "Remove the .zarray file, or the .zgroup file, or both."
    )
+ chunk_key_encoding: ChunkKeyEncoding | tuple[Literal["default"], Literal[".", "/"]] | tuple[Literal["v2"], Literal[".", "/"]] | None = None + A specification of how the chunk keys are represented in storage. + codecs: Iterable[Codec | dict[str, JSON]] | None = None + An iterable of Codec or dict serializations thereof. The elements of + this collection specify the transformation from array values to stored bytes. + dimension_names: Iterable[str] | None = None + The names of the dimensions of the array. V3 only. + chunks: ChunkCoords | None = None + The shape of the chunks of the array. V2 only. + dimension_separator: Literal[".", "/"] | None = None + The delimiter used for the chunk keys. + order: Literal["C", "F"] | None = None + The memory order of the array. + filters: list[dict[str, JSON]] | None = None + Filters for the array. + compressor: dict[str, JSON] | None = None + The compressor for the array. + exists_ok: bool = False + If True, a pre-existing array or group at the path of this array will + be overwritten. If False, the presence of a pre-existing array or group is + an error. 
def create_array(
    self,
    name: str,
    *,
    shape: ChunkCoords,
    dtype: npt.DTypeLike = "float64",
    fill_value: Any | None = None,
    attributes: dict[str, JSON] | None = None,
    # v3 only
    chunk_shape: ChunkCoords | None = None,
    chunk_key_encoding: (
        ChunkKeyEncoding
        | tuple[Literal["default"], Literal[".", "/"]]
        | tuple[Literal["v2"], Literal[".", "/"]]
        | None
    ) = None,
    codecs: Iterable[Codec | dict[str, JSON]] | None = None,
    dimension_names: Iterable[str] | None = None,
    # v2 only
    chunks: ChunkCoords | None = None,
    dimension_separator: Literal[".", "/"] | None = None,
    order: Literal["C", "F"] | None = None,
    filters: list[dict[str, JSON]] | None = None,
    compressor: dict[str, JSON] | None = None,
    # runtime
    exists_ok: bool = False,
    data: npt.ArrayLike | None = None,
) -> Array:
    """
    Create a Zarr array within this Group.

    This is the synchronous counterpart of `AsyncGroup.create_array`: every argument is
    forwarded unchanged to the wrapped async group (``name`` is passed as ``path``), the
    call is run to completion via ``self._sync``, and the result is wrapped in an `Array`.

    Parameters
    ----------
    name: str
        The name of the array.
    shape: tuple[int, ...]
        The shape of the array.
    dtype: npt.DTypeLike = "float64"
        The data type of the array.
    fill_value: Any | None = None
        Forwarded as-is to `AsyncGroup.create_array`.
    attributes: dict[str, JSON] | None = None
        Forwarded as-is to `AsyncGroup.create_array`.
    chunk_shape: tuple[int, ...] | None = None
        The shape of the chunks of the array. V3 only.
    chunk_key_encoding: ChunkKeyEncoding | tuple[Literal["default"], Literal[".", "/"]] | tuple[Literal["v2"], Literal[".", "/"]] | None = None
        A specification of how the chunk keys are represented in storage.
    codecs: Iterable[Codec | dict[str, JSON]] | None = None
        An iterable of Codec or dict serializations thereof. The elements of this collection
        specify the transformation from array values to stored bytes.
    dimension_names: Iterable[str] | None = None
        The names of the dimensions of the array. V3 only.
    chunks: ChunkCoords | None = None
        The shape of the chunks of the array. V2 only.
    dimension_separator: Literal[".", "/"] | None = None
        The delimiter used for the chunk keys.
    order: Literal["C", "F"] | None = None
        The memory order of the array.
    filters: list[dict[str, JSON]] | None = None
        Filters for the array.
    compressor: dict[str, JSON] | None = None
        The compressor for the array.
    exists_ok: bool = False
        If True, a pre-existing array or group at the path of this array will
        be overwritten. If False, the presence of a pre-existing array or group is
        an error.
    data: npt.ArrayLike | None = None
        Array data to initialize the array with.

    Returns
    -------
    Array

    """
    return Array(
        self._sync(
            self._async_group.create_array(
                path=name,
                shape=shape,
                dtype=dtype,
                fill_value=fill_value,
                attributes=attributes,
                chunk_shape=chunk_shape,
                chunk_key_encoding=chunk_key_encoding,
                codecs=codecs,
                dimension_names=dimension_names,
                chunks=chunks,
                dimension_separator=dimension_separator,
                order=order,
                filters=filters,
                compressor=compressor,
                exists_ok=exists_ok,
                data=data,
            )
        )
    )
| None = None + The shape of the chunks of the array. V3 only. + chunk_key_encoding: ChunkKeyEncoding | tuple[Literal["default"], Literal[".", "/"]] | tuple[Literal["v2"], Literal[".", "/"]] | None = None + A specification of how the chunk keys are represented in storage. + codecs: Iterable[Codec | dict[str, JSON]] | None = None + An iterable of Codec or dict serializations thereof. The elements of + this collection specify the transformation from array values to stored bytes. + dimension_names: Iterable[str] | None = None + The names of the dimensions of the array. V3 only. + chunks: ChunkCoords | None = None + The shape of the chunks of the array. V2 only. + dimension_separator: Literal[".", "/"] | None = None + The delimiter used for the chunk keys. + order: Literal["C", "F"] | None = None + The memory order of the array. + filters: list[dict[str, JSON]] | None = None + Filters for the array. + compressor: dict[str, JSON] | None = None + The compressor for the array. + exists_ok: bool = False + If True, a pre-existing array or group at the path of this array will + be overwritten. If False, the presence of a pre-existing array or group is + an error. + data: npt.ArrayLike | None = None + Array data to initialize the array with. 
async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None:
    """
    Check if a store_path is safe for array / group creation.
    Returns `None` or raises an exception.

    Parameters
    ----------
    store_path: StorePath
        The storage location to check.
    zarr_format: ZarrFormat
        The Zarr format to check.

    Raises
    ------
    ContainsArrayError
        If an array already exists at ``store_path``.
    ContainsGroupError
        If a group already exists at ``store_path``.
    ContainsArrayAndGroupError
        Propagated from the Zarr v2 check when both array and group metadata are found.
    ValueError
        If ``zarr_format`` is neither 2 nor 3.
    """
    if zarr_format == 2:
        extant_node = await _contains_node_v2(store_path)
    elif zarr_format == 3:
        extant_node = await _contains_node_v3(store_path)
    else:
        # Fail with a clear error instead of leaving `extant_node` unbound, which would
        # surface as an UnboundLocalError on the comparison below.
        msg = f"Invalid zarr_format provided. Got {zarr_format}, expected 2 or 3"  # type: ignore[unreachable]
        raise ValueError(msg)

    if extant_node == "array":
        raise ContainsArrayError(store_path.store, store_path.path)
    elif extant_node == "group":
        raise ContainsGroupError(store_path.store, store_path.path)
    elif extant_node == "nothing":
        return
    msg = f"Invalid value for extant_node: {extant_node}"  # type: ignore[unreachable]
    raise ValueError(msg)


async def _contains_node_v3(store_path: StorePath) -> Literal["array", "group", "nothing"]:
    """
    Check if a store_path contains nothing, an array, or a group. This function
    returns the string "array", "group", or "nothing" to denote containing an array, a group, or
    nothing.

    Parameters
    ----------
    store_path: StorePath
        The location in storage to check.

    Returns
    -------
    Literal["array", "group", "nothing"]
        A string representing the zarr node found at store_path.
    """
    extant_meta_bytes = await (store_path / ZARR_JSON).get()
    # if no metadata document could be loaded, then we just return "nothing"
    if extant_meta_bytes is None:
        return "nothing"

    try:
        # avoid constructing a full metadata document here in the name of speed.
        node_type = json.loads(extant_meta_bytes.to_bytes())["node_type"]
    except (KeyError, TypeError, ValueError):
        # A document that is not decodable JSON (ValueError covers JSONDecodeError and
        # UnicodeDecodeError), is not a JSON object (TypeError), or lacks "node_type"
        # (KeyError) is consistent with no array or group being present.
        return "nothing"

    if node_type == "array":
        return "array"
    if node_type == "group":
        return "group"
    return "nothing"
async def contains_array(store_path: StorePath, zarr_format: ZarrFormat) -> bool:
    """
    Check if an array exists at a given StorePath.

    Parameters
    ----------
    store_path: StorePath
        The StorePath to check for an existing array.
    zarr_format:
        The zarr format to check for.

    Returns
    -------
    bool
        True if the StorePath contains an array, False otherwise.

    Raises
    ------
    ValueError
        If ``zarr_format`` is neither 2 nor 3.

    """
    if zarr_format == 3:
        extant_meta_bytes = await (store_path / ZARR_JSON).get()
        if extant_meta_bytes is None:
            return False
        try:
            extant_meta_json = json.loads(extant_meta_bytes.to_bytes())
            # we avoid constructing a full metadata document here in the name of speed.
            # Return the comparison itself: a valid document describing some other node
            # type (e.g. a group) means "no array here", and must not fall through to the
            # invalid-zarr_format error at the bottom of this function.
            result: bool = extant_meta_json["node_type"] == "array"
            return result
        except (ValueError, KeyError, TypeError):
            # Undecodable JSON, a missing "node_type" key, or a document that is not a
            # JSON object all mean that no array metadata is present.
            return False
    elif zarr_format == 2:
        return await (store_path / ZARRAY_JSON).exists()
    msg = f"Invalid zarr_format provided. Got {zarr_format}, expected 2 or 3"  # type: ignore[unreachable]
    raise ValueError(msg)
+ + Returns + ------- + + bool + True if the StorePath contains a group, False otherwise + + """ + if zarr_format == 3: + extant_meta_bytes = await (store_path / ZARR_JSON).get() + if extant_meta_bytes is None: + return False + else: + try: + extant_meta_json = json.loads(extant_meta_bytes.to_bytes()) + # we avoid constructing a full metadata document here in the name of speed. + result: bool = extant_meta_json["node_type"] == "group" + return result + except (ValueError, KeyError): + return False + elif zarr_format == 2: + return await (store_path / ZGROUP_JSON).exists() + msg = f"Invalid zarr_format provided. Got {zarr_format}, expected 2 or 3" # type: ignore[unreachable] + raise ValueError(msg) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 7b73330b6c..7f3c575719 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -95,8 +95,9 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: prefix = prefix[:-1] if prefix == "": - for key in self._store_dict: - yield key.split("/", maxsplit=1)[0] + keys_unique = set(k.split("/")[0] for k in self._store_dict.keys()) + for key in keys_unique: + yield key else: for key in self._store_dict: if key.startswith(prefix + "/") and key != prefix: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 9c37ce0434..a4e154bbc9 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -183,8 +183,12 @@ async def test_list_dir(self, store: S) -> None: await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) await store.set("foo/c/1", Buffer.from_bytes(b"\x01")) - keys = [k async for k in store.list_dir("foo")] - assert set(keys) == set(["zarr.json", "c"]), keys + keys_expected = ["zarr.json", "c"] + keys_observed = [k async for k in store.list_dir("foo")] - keys = [k async for k in store.list_dir("foo/")] - assert set(keys) == set(["zarr.json", "c"]), keys + assert len(keys_observed) == len(keys_expected), keys_observed + assert 
@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
@pytest.mark.parametrize("exists_ok", [True, False])
@pytest.mark.parametrize("extant_node", ["array", "group"])
def test_array_creation_existing_node(
    store: LocalStore | MemoryStore,
    zarr_format: ZarrFormat,
    exists_ok: bool,
    extant_node: Literal["array", "group"],
) -> None:
    """
    Creating an array where an array or group already exists must succeed when
    ``exists_ok`` is True and raise the matching "contains" error otherwise.
    """
    root_path = StorePath(store)
    parent = Group.create(root_path, zarr_format=zarr_format)

    # Occupy the target location with the requested kind of node and remember which
    # error a conflicting creation is expected to raise.
    collision_error: type[ContainsArrayError] | type[ContainsGroupError]
    if extant_node == "array":
        collision_error = ContainsArrayError
        parent.create_array("extant", shape=(10,), dtype="uint8")
    elif extant_node == "group":
        collision_error = ContainsGroupError
        parent.create_group("extant")
    else:
        raise AssertionError

    requested_shape = (2, 2)
    requested_dtype = "float32"

    def recreate() -> Array:
        # Attempt to create a new array on top of the existing node.
        return Array.create(
            root_path / "extant",
            shape=requested_shape,
            dtype=requested_dtype,
            exists_ok=exists_ok,
            zarr_format=zarr_format,
        )

    if not exists_ok:
        with pytest.raises(collision_error):
            recreate()
        return

    replacement = recreate()
    assert replacement.shape == requested_shape
    assert replacement.dtype == requested_dtype
-1,26 +1,77 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any, Literal, cast -from zarr.array import AsyncArray +import numpy as np +import pytest +from _pytest.compat import LEGACY_PATH + +from zarr.array import Array, AsyncArray from zarr.buffer import Buffer +from zarr.common import ZarrFormat +from zarr.errors import ContainsArrayError, ContainsGroupError +from zarr.group import AsyncGroup, Group, GroupMetadata +from zarr.store import LocalStore, MemoryStore, StorePath from zarr.store.core import make_store_path from zarr.sync import sync -if TYPE_CHECKING: - from zarr.common import ZarrFormat - from zarr.store import LocalStore, MemoryStore +from .conftest import parse_store -import numpy as np -import pytest -from zarr.group import AsyncGroup, Group, GroupMetadata -from zarr.store import StorePath +@pytest.fixture(params=["local", "memory"]) +def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> LocalStore | MemoryStore: + result = parse_store(request.param, str(tmpdir)) + if not isinstance(result, MemoryStore | LocalStore): + raise TypeError("Wrong store class returned by test fixture!") + return result -# todo: put RemoteStore in here -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -def test_group_children(store: MemoryStore | LocalStore) -> None: +@pytest.fixture(params=[True, False]) +def exists_ok(request: pytest.FixtureRequest) -> bool: + result = request.param + if not isinstance(result, bool): + raise TypeError("Wrong type returned by test fixture.") + return result + + +@pytest.fixture(params=[2, 3], ids=["zarr2", "zarr3"]) +def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat: + result = request.param + if result not in (2, 3): + raise ValueError("Wrong value returned from test fixture.") + return cast(ZarrFormat, result) + + +def test_group_init(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + """ + Test that initializing a group 
from an asyncgroup works. + """ + agroup = sync(AsyncGroup.create(store=store, zarr_format=zarr_format)) + group = Group(agroup) + assert group._async_group == agroup + + +def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + """ + Test basic properties of groups + """ + root = Group.create(store=store, zarr_format=zarr_format) + assert root.path == "" + assert root.name == "/" + assert root.basename == "" + + foo = root.create_group("foo") + assert foo.path == "foo" + assert foo.name == "/foo" + assert foo.basename == "foo" + + bar = root.create_group("foo/bar") + assert bar.path == "foo/bar" + assert bar.name == "/foo/bar" + assert bar.basename == "bar" + + +def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: """ Test that `Group.members` returns correct values, i.e. the arrays and groups (explicit and implicit) contained in that group. @@ -28,16 +79,16 @@ def test_group_children(store: MemoryStore | LocalStore) -> None: path = "group" agroup = AsyncGroup( - metadata=GroupMetadata(), + metadata=GroupMetadata(zarr_format=zarr_format), store_path=StorePath(store=store, path=path), ) group = Group(agroup) - members_expected = {} + members_expected: dict[str, Array | Group] = {} members_expected["subgroup"] = group.create_group("subgroup") # make a sub-sub-subgroup, to ensure that the children calculation doesn't go # too deep in the hierarchy - _ = members_expected["subgroup"].create_group("subsubgroup") + _ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore members_expected["subarray"] = group.create_array( "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True @@ -55,12 +106,15 @@ def test_group_children(store: MemoryStore | LocalStore) -> None: assert sorted(dict(members_observed)) == sorted(members_expected) -@pytest.mark.parametrize("store", (("local", "memory")), indirect=["store"]) -def test_group(store: MemoryStore | LocalStore) -> 
None: +def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test basic Group routines. + """ store_path = StorePath(store) - agroup = AsyncGroup(metadata=GroupMetadata(), store_path=store_path) + agroup = AsyncGroup(metadata=GroupMetadata(zarr_format=zarr_format), store_path=store_path) group = Group(agroup) assert agroup.metadata is group.metadata + assert agroup.store_path == group.store_path == store_path # create two groups foo = group.create_group("foo") @@ -94,29 +148,270 @@ def test_group(store: MemoryStore | LocalStore) -> None: assert dict(bar3.attrs) == {"baz": "qux", "name": "bar"} -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("exists_ok", (True, False)) -def test_group_create(store: MemoryStore | LocalStore, exists_ok: bool) -> None: +def test_group_create( + store: MemoryStore | LocalStore, exists_ok: bool, zarr_format: ZarrFormat +) -> None: """ Test that `Group.create` works as expected. """ attributes = {"foo": 100} - group = Group.create(store, attributes=attributes, exists_ok=exists_ok) + group = Group.create(store, attributes=attributes, zarr_format=zarr_format, exists_ok=exists_ok) assert group.attrs == attributes if not exists_ok: - with pytest.raises(AssertionError): + with pytest.raises(ContainsGroupError): group = Group.create( - store, - attributes=attributes, - exists_ok=exists_ok, + store, attributes=attributes, exists_ok=exists_ok, zarr_format=zarr_format ) +def test_group_open( + store: MemoryStore | LocalStore, zarr_format: ZarrFormat, exists_ok: bool +) -> None: + """ + Test the `Group.open` method. 
+ """ + spath = StorePath(store) + # attempt to open a group that does not exist + with pytest.raises(FileNotFoundError): + Group.open(store) + + # create the group + attrs = {"path": "foo"} + group_created = Group.create( + store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok + ) + assert group_created.attrs == attrs + assert group_created.metadata.zarr_format == zarr_format + assert group_created.store_path == spath + + # attempt to create a new group in place, to test exists_ok + new_attrs = {"path": "bar"} + if not exists_ok: + with pytest.raises(ContainsGroupError): + Group.create(store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok) + else: + group_created_again = Group.create( + store, attributes=new_attrs, zarr_format=zarr_format, exists_ok=exists_ok + ) + assert group_created_again.attrs == new_attrs + assert group_created_again.metadata.zarr_format == zarr_format + assert group_created_again.store_path == spath + + +def test_group_getitem(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__getitem__` method. + """ + + group = Group.create(store, zarr_format=zarr_format) + subgroup = group.create_group(name="subgroup") + subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) + + assert group["subgroup"] == subgroup + assert group["subarray"] == subarray + with pytest.raises(KeyError): + group["nope"] + + +def test_group_delitem(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__delitem__` method. 
+ """ + + group = Group.create(store, zarr_format=zarr_format) + subgroup = group.create_group(name="subgroup") + subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) + + assert group["subgroup"] == subgroup + assert group["subarray"] == subarray + + del group["subgroup"] + with pytest.raises(KeyError): + group["subgroup"] + + del group["subarray"] + with pytest.raises(KeyError): + group["subarray"] + + +def test_group_iter(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__iter__` method. + """ + + group = Group.create(store, zarr_format=zarr_format) + with pytest.raises(NotImplementedError): + [x for x in group] # type: ignore + + +def test_group_len(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__len__` method. + """ + + group = Group.create(store, zarr_format=zarr_format) + with pytest.raises(NotImplementedError): + len(group) # type: ignore + + +def test_group_setitem(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__setitem__` method. 
+ """ + group = Group.create(store, zarr_format=zarr_format) + with pytest.raises(NotImplementedError): + group["key"] = 10 + + +def test_group_contains(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__contains__` method + """ + group = Group.create(store, zarr_format=zarr_format) + assert "foo" not in group + _ = group.create_group(name="foo") + assert "foo" in group + + +def test_group_subgroups(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups` + """ + group = Group.create(store, zarr_format=zarr_format) + keys = ("foo", "bar") + subgroups_expected = tuple(group.create_group(k) for k in keys) + # create a sub-array as well + _ = group.create_array("array", shape=(10,)) + subgroups_observed = group.groups() + assert set(group.group_keys()) == set(keys) + assert len(subgroups_observed) == len(subgroups_expected) + assert all(a in subgroups_observed for a in subgroups_expected) + + +def test_group_subarrays(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the behavior of `Group` methods for accessing subarrays, namely `Group.array_keys` and `Group.arrays` + """ + group = Group.create(store, zarr_format=zarr_format) + keys = ("foo", "bar") + subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys) + # create a sub-group as well + _ = group.create_group("group") + subarrays_observed = group.arrays() + assert set(group.array_keys()) == set(keys) + assert len(subarrays_observed) == len(subarrays_expected) + assert all(a in subarrays_observed for a in subarrays_expected) + + +def test_group_update_attributes(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: + """ + Test the behavior of `Group.update_attributes` + """ + attrs = {"foo": 100} + group = Group.create(store, zarr_format=zarr_format, attributes=attrs) + assert group.attrs 
== attrs + new_attrs = {"bar": 100} + new_group = group.update_attributes(new_attrs) + assert new_group.attrs == new_attrs + + +async def test_group_update_attributes_async( + store: MemoryStore | LocalStore, zarr_format: ZarrFormat +) -> None: + """ + Test the behavior of `Group.update_attributes_async` + """ + attrs = {"foo": 100} + group = Group.create(store, zarr_format=zarr_format, attributes=attrs) + assert group.attrs == attrs + new_attrs = {"bar": 100} + new_group = await group.update_attributes_async(new_attrs) + assert new_group.attrs == new_attrs + + +@pytest.mark.parametrize("method", ["create_array", "array"]) +def test_group_create_array( + store: MemoryStore | LocalStore, + zarr_format: ZarrFormat, + exists_ok: bool, + method: Literal["create_array", "array"], +) -> None: + """ + Test `Group.create_array` + """ + group = Group.create(store, zarr_format=zarr_format) + shape = (10, 10) + dtype = "uint8" + data = np.arange(np.prod(shape)).reshape(shape).astype(dtype) + + if method == "create_array": + array = group.create_array(name="array", shape=shape, dtype=dtype, data=data) + elif method == "array": + array = group.array(name="array", shape=shape, dtype=dtype, data=data) + else: + raise AssertionError + + if not exists_ok: + if method == "create_array": + with pytest.raises(ContainsArrayError): + group.create_array(name="array", shape=shape, dtype=dtype, data=data) + elif method == "array": + with pytest.raises(ContainsArrayError): + group.array(name="array", shape=shape, dtype=dtype, data=data) + assert array.shape == shape + assert array.dtype == np.dtype(dtype) + assert np.array_equal(array[:], data) + + @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -@pytest.mark.parametrize("exists_ok", (True, False)) +@pytest.mark.parametrize("exists_ok", [True, False]) +@pytest.mark.parametrize("extant_node", ["array", "group"]) +def test_group_creation_existing_node( + store: 
LocalStore | MemoryStore, + zarr_format: ZarrFormat, + exists_ok: bool, + extant_node: Literal["array", "group"], +) -> None: + """ + Check that an existing array or group is handled as expected during group creation. + """ + spath = StorePath(store) + group = Group.create(spath, zarr_format=zarr_format) + expected_exception: type[ContainsArrayError] | type[ContainsGroupError] + attributes = {"old": True} + + if extant_node == "array": + expected_exception = ContainsArrayError + _ = group.create_array("extant", shape=(10,), dtype="uint8", attributes=attributes) + elif extant_node == "group": + expected_exception = ContainsGroupError + _ = group.create_group("extant", attributes=attributes) + else: + raise AssertionError + + new_attributes = {"new": True} + + if exists_ok: + node_new = Group.create( + spath / "extant", + attributes=new_attributes, + zarr_format=zarr_format, + exists_ok=exists_ok, + ) + assert node_new.attrs == new_attributes + else: + with pytest.raises(expected_exception): + node_new = Group.create( + spath / "extant", + attributes=new_attributes, + zarr_format=zarr_format, + exists_ok=exists_ok, + ) + + async def test_asyncgroup_create( store: MemoryStore | LocalStore, exists_ok: bool, @@ -125,6 +420,7 @@ async def test_asyncgroup_create( """ Test that `AsyncGroup.create` works as expected. 
""" + spath = StorePath(store=store) attributes = {"foo": 100} agroup = await AsyncGroup.create( store, @@ -137,17 +433,27 @@ async def test_asyncgroup_create( assert agroup.store_path == make_store_path(store) if not exists_ok: - with pytest.raises(AssertionError): + with pytest.raises(ContainsGroupError): agroup = await AsyncGroup.create( - store, + spath, + attributes=attributes, + exists_ok=exists_ok, + zarr_format=zarr_format, + ) + # create an array at our target path + collision_name = "foo" + _ = await AsyncArray.create( + spath / collision_name, shape=(10,), dtype="uint8", zarr_format=zarr_format + ) + with pytest.raises(ContainsArrayError): + _ = await AsyncGroup.create( + StorePath(store=store) / collision_name, attributes=attributes, exists_ok=exists_ok, zarr_format=zarr_format, ) -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_attrs(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: attributes = {"foo": 100} agroup = await AsyncGroup.create(store, zarr_format=zarr_format, attributes=attributes) @@ -155,8 +461,6 @@ async def test_asyncgroup_attrs(store: LocalStore | MemoryStore, zarr_format: Za assert agroup.attrs == agroup.metadata.attributes == attributes -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_info(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: agroup = await AsyncGroup.create( # noqa store, @@ -166,8 +470,6 @@ async def test_asyncgroup_info(store: LocalStore | MemoryStore, zarr_format: Zar # assert agroup.info == agroup.metadata.info -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_open( store: LocalStore | MemoryStore, zarr_format: ZarrFormat, @@ -189,14 +491,12 @@ async def test_asyncgroup_open( assert 
group_w == group_r -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_open_wrong_format( store: LocalStore | MemoryStore, zarr_format: ZarrFormat, ) -> None: _ = await AsyncGroup.create(store=store, exists_ok=False, zarr_format=zarr_format) - + zarr_format_wrong: ZarrFormat # try opening with the wrong zarr format if zarr_format == 3: zarr_format_wrong = 2 @@ -211,7 +511,6 @@ async def test_asyncgroup_open_wrong_format( # todo: replace the dict[str, Any] type with something a bit more specific # should this be async? -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) @pytest.mark.parametrize( "data", ( @@ -234,8 +533,6 @@ def test_asyncgroup_from_dict(store: MemoryStore | LocalStore, data: dict[str, A # todo: replace this with a declarative API where we model a full hierarchy -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_getitem(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: """ Create an `AsyncGroup`, then create members of that group, and ensure that we can access those @@ -258,11 +555,6 @@ async def test_asyncgroup_getitem(store: LocalStore | MemoryStore, zarr_format: await agroup.getitem("foo") -# todo: replace this with a declarative API where we model a full hierarchy - - -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_delitem(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) sub_array_path = "sub_array" @@ -292,8 +584,6 @@ async def test_asyncgroup_delitem(store: LocalStore | MemoryStore, zarr_format: raise AssertionError -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) 
-@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_create_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat, @@ -310,11 +600,8 @@ async def test_asyncgroup_create_group( assert subnode.metadata.zarr_format == zarr_format -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_create_array( - store: LocalStore | MemoryStore, - zarr_format: ZarrFormat, + store: LocalStore | MemoryStore, zarr_format: ZarrFormat, exists_ok: bool ) -> None: """ Test that the AsyncGroup.create_array method works correctly. We ensure that array properties @@ -323,6 +610,10 @@ async def test_asyncgroup_create_array( agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + if not exists_ok: + with pytest.raises(ContainsGroupError): + agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + shape = (10,) dtype = "uint8" chunk_shape = (4,) @@ -348,8 +639,6 @@ async def test_asyncgroup_create_array( assert subnode.metadata.zarr_format == zarr_format -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) async def test_asyncgroup_update_attributes( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: @@ -364,30 +653,3 @@ async def test_asyncgroup_update_attributes( agroup_new_attributes = await agroup.update_attributes(attributes_new) assert agroup_new_attributes.attrs == attributes_new - - -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -@pytest.mark.parametrize("zarr_format", (2, 3)) -def test_group_init(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - agroup = sync(AsyncGroup.create(store=store, zarr_format=zarr_format)) - group = Group(agroup) - assert group._async_group == agroup - - -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) 
-@pytest.mark.parametrize("zarr_format", (2, 3)) -def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = Group.create(store=store, zarr_format=zarr_format) - assert root.path == "" - assert root.name == "/" - assert root.basename == "" - - foo = root.create_group("foo") - assert foo.path == "foo" - assert foo.name == "/foo" - assert foo.basename == "foo" - - bar = root.create_group("foo/bar") - assert bar.path == "foo/bar" - assert bar.name == "/foo/bar" - assert bar.basename == "bar" diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 96b8b19e2c..dd3cad7d7e 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -17,7 +17,9 @@ def get(self, store: MemoryStore, key: str) -> Buffer: return store._store_dict[key] @pytest.fixture(scope="function", params=[None, {}]) - def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]: + def store_kwargs( + self, request: pytest.FixtureRequest + ) -> dict[str, str | None | dict[str, Buffer]]: return {"store_dict": request.param, "mode": "w"} @pytest.fixture(scope="function")