diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py
index 6c2e15c54e9..0af8084dd21 100644
--- a/asv_bench/benchmarks/dataset_io.py
+++ b/asv_bench/benchmarks/dataset_io.py
@@ -1,11 +1,14 @@
+from __future__ import annotations
+
 import os
+from dataclasses import dataclass

 import numpy as np
 import pandas as pd

 import xarray as xr

-from . import _skip_slow, randint, randn, requires_dask
+from . import _skip_slow, parameterized, randint, randn, requires_dask

 try:
     import dask
@@ -16,6 +19,8 @@

 os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

+_ENGINES = tuple(xr.backends.list_engines().keys() - {"store"})
+

 class IOSingleNetCDF:
     """
@@ -28,10 +33,6 @@ class IOSingleNetCDF:
     number = 5

     def make_ds(self):
-        # TODO: Lazily skipped in CI as it is very demanding and slow.
-        # Improve times and remove errors.
-        _skip_slow()
-
         # single Dataset
         self.ds = xr.Dataset()
         self.nt = 1000
@@ -95,6 +96,10 @@ def make_ds(self):

 class IOWriteSingleNetCDF3(IOSingleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()
+
         self.format = "NETCDF3_64BIT"
         self.make_ds()
@@ -107,6 +112,9 @@ def time_write_dataset_scipy(self):

 class IOReadSingleNetCDF4(IOSingleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         self.make_ds()
@@ -128,6 +136,9 @@ def time_vectorized_indexing(self):

 class IOReadSingleNetCDF3(IOReadSingleNetCDF4):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         self.make_ds()
@@ -149,6 +160,9 @@ def time_vectorized_indexing(self):

 class IOReadSingleNetCDF4Dask(IOSingleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -189,6 +203,9 @@ def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self):

 class IOReadSingleNetCDF3Dask(IOReadSingleNetCDF4Dask):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -230,10 +247,6 @@ class IOMultipleNetCDF:
     number = 5

     def make_ds(self, nfiles=10):
-        # TODO: Lazily skipped in CI as it is very demanding and slow.
-        # Improve times and remove errors.
-        _skip_slow()
-
         # multiple Dataset
         self.ds = xr.Dataset()
         self.nt = 1000
@@ -298,6 +311,10 @@ def make_ds(self, nfiles=10):

 class IOWriteMultipleNetCDF3(IOMultipleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()
+
         self.make_ds()
         self.format = "NETCDF3_64BIT"
@@ -314,6 +331,9 @@ def time_write_dataset_scipy(self):

 class IOReadMultipleNetCDF4(IOMultipleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -330,6 +350,9 @@ def time_open_dataset_netcdf4(self):

 class IOReadMultipleNetCDF3(IOReadMultipleNetCDF4):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -346,6 +369,9 @@ def time_open_dataset_scipy(self):

 class IOReadMultipleNetCDF4Dask(IOMultipleNetCDF):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -400,6 +426,9 @@ def time_open_dataset_netcdf4_with_time_chunks_multiprocessing(self):

 class IOReadMultipleNetCDF3Dask(IOReadMultipleNetCDF4Dask):
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()

         requires_dask()
@@ -435,10 +464,6 @@ def time_open_dataset_scipy_with_time_chunks(self):
 def create_delayed_write():
     import dask.array as da

-    # TODO: Lazily skipped in CI as it is very demanding and slow.
-    # Improve times and remove errors.
-    _skip_slow()
-
     vals = da.random.random(300, chunks=(1,))
     ds = xr.Dataset({"vals": (["a"], vals)})
     return ds.to_netcdf("file.nc", engine="netcdf4", compute=False)
@@ -450,7 +475,12 @@ class IOWriteNetCDFDask:
     number = 5

     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()
+
         requires_dask()
+
         self.write = create_delayed_write()

     def time_write(self):
@@ -459,15 +489,17 @@ def time_write(self):

 class IOWriteNetCDFDaskDistributed:
     def setup(self):
+        # TODO: Lazily skipped in CI as it is very demanding and slow.
+        # Improve times and remove errors.
+        _skip_slow()
+
+        requires_dask()
+
         try:
             import distributed
         except ImportError:
             raise NotImplementedError()

-        # TODO: Lazily skipped in CI as it is very demanding and slow.
-        # Improve times and remove errors.
-        _skip_slow()
-
         self.client = distributed.Client()
         self.write = create_delayed_write()

@@ -476,3 +508,145 @@ def cleanup(self):

     def time_write(self):
         self.write.compute()
+
+
+class IOReadSingleFile(IOSingleNetCDF):
+    def setup(self, *args, **kwargs):
+        self.make_ds()
+
+        self.filepaths = {}
+        for engine in _ENGINES:
+            self.filepaths[engine] = f"test_single_file_with_{engine}.nc"
+            self.ds.to_netcdf(self.filepaths[engine], engine=engine)
+
+    @parameterized(["engine", "chunks"], (_ENGINES, [None, {}]))
+    def time_read_dataset(self, engine, chunks):
+        xr.open_dataset(self.filepaths[engine], engine=engine, chunks=chunks)
+
+
+class IOReadCustomEngine:
+    def setup(self, *args, **kwargs):
+        """
+        The custom backend does the bare minimum to be considered a lazy
+        backend, but its data is already in memory, so slow file reading
+        shouldn't affect the results.
+        """
+        requires_dask()
+
+        @dataclass
+        class PerformanceBackendArray(xr.backends.BackendArray):
+            filename_or_obj: str | os.PathLike | None
+            shape: tuple[int, ...]
+            dtype: np.dtype
+            lock: xr.backends.locks.SerializableLock
+
+            def __getitem__(self, key: tuple):
+                return xr.core.indexing.explicit_indexing_adapter(
+                    key,
+                    self.shape,
+                    xr.core.indexing.IndexingSupport.BASIC,
+                    self._raw_indexing_method,
+                )
+
+            def _raw_indexing_method(self, key: tuple):
+                raise NotImplementedError
+
+        @dataclass
+        class PerformanceStore(xr.backends.common.AbstractWritableDataStore):
+            manager: xr.backends.CachingFileManager
+            mode: str | None = None
+            lock: xr.backends.locks.SerializableLock | None = None
+            autoclose: bool = False
+
+            def __post_init__(self):
+                self.filename = self.manager._args[0]
+
+            @classmethod
+            def open(
+                cls,
+                filename: str | os.PathLike | None,
+                mode: str = "r",
+                lock: xr.backends.locks.SerializableLock | None = None,
+                autoclose: bool = False,
+            ):
+                if lock is None:
+                    locker = xr.backends.locks.SerializableLock()
+                else:
+                    locker = lock
+
+                manager = xr.backends.CachingFileManager(
+                    xr.backends.DummyFileManager,
+                    filename,
+                    mode=mode,
+                )
+                return cls(manager, mode=mode, lock=locker, autoclose=autoclose)
+
+            def load(self) -> tuple:
+                """
+                Load a bunch of test data quickly.
+
+                Normally this method would've opened a file and parsed it.
+                """
+                n_variables = 2000
+
+                # A shape and dtype are required for lazy loading.
+                shape = (1,)
+                dtype = np.dtype(int)
+                variables = {
+                    f"long_variable_name_{v}": xr.Variable(
+                        data=PerformanceBackendArray(
+                            self.filename, shape, dtype, self.lock
+                        ),
+                        dims=("time",),
+                        fastpath=True,
+                    )
+                    for v in range(n_variables)
+                }
+                attributes = {}
+
+                return variables, attributes
+
+        class PerformanceBackend(xr.backends.BackendEntrypoint):
+            def open_dataset(
+                self,
+                filename_or_obj: str | os.PathLike | None,
+                drop_variables: tuple[str] | None = None,
+                *,
+                mask_and_scale=True,
+                decode_times=True,
+                concat_characters=True,
+                decode_coords=True,
+                use_cftime=None,
+                decode_timedelta=None,
+                lock=None,
+                **kwargs,
+            ) -> xr.Dataset:
+                filename_or_obj = xr.backends.common._normalize_path(filename_or_obj)
+                store = PerformanceStore.open(filename_or_obj, lock=lock)
+
+                store_entrypoint = xr.backends.store.StoreBackendEntrypoint()
+
+                ds = store_entrypoint.open_dataset(
+                    store,
+                    mask_and_scale=mask_and_scale,
+                    decode_times=decode_times,
+                    concat_characters=concat_characters,
+                    decode_coords=decode_coords,
+                    drop_variables=drop_variables,
+                    use_cftime=use_cftime,
+                    decode_timedelta=decode_timedelta,
+                )
+                return ds
+
+        self.engine = PerformanceBackend
+
+    @parameterized(["chunks"], ([None, {}]))
+    def time_open_dataset(self, chunks):
+        """
+        Time how fast xr.open_dataset is without the slow data reading part.
+        Test with and without dask.
+        """
+        xr.open_dataset(None, engine=self.engine, chunks=chunks)
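
Note for reviewers: the benchmarks above lean on the `parameterized` helper imported from the benchmark package (`from . import ... parameterized ...`), which is not shown in this diff. As context, here is a minimal sketch of what such a helper typically looks like, assuming asv's convention of reading `param_names` and `params` attributes off benchmark callables; the body below is an illustration, not part of this change:

    def parameterized(names, params):
        # Sketch only: asv discovers the ``param_names`` and ``params``
        # attributes on a benchmark callable and runs the benchmark once
        # per parameter combination, labelling each run in the report.
        def decorator(func):
            func.param_names = names
            func.params = params
            return func

        return decorator

With that in place, the new benchmarks can be exercised in isolation via asv's benchmark filter, e.g. `asv run --bench IOReadCustomEngine`, to time `xr.open_dataset` overhead with and without dask chunking.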